ConnectOAuth2RequestFactory.java

  1. /*******************************************************************************
  2.  * Copyright 2017 The MIT Internet Trust Consortium
  3.  *
  4.  * Portions copyright 2011-2013 The MITRE Corporation
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the "License");
  7.  * you may not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  *   http://www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an "AS IS" BASIS,
  14.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  *******************************************************************************/
  18. package org.mitre.openid.connect.request;


  19. import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD;
  20. import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS;
  21. import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID;
  22. import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
  23. import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
  24. import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY;
  25. import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT;
  26. import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE;
  27. import static org.mitre.openid.connect.request.ConnectRequestParameters.NONCE;
  28. import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT;
  29. import static org.mitre.openid.connect.request.ConnectRequestParameters.REDIRECT_URI;
  30. import static org.mitre.openid.connect.request.ConnectRequestParameters.REQUEST;
  31. import static org.mitre.openid.connect.request.ConnectRequestParameters.RESPONSE_TYPE;
  32. import static org.mitre.openid.connect.request.ConnectRequestParameters.SCOPE;
  33. import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE;

  34. import java.io.Serializable;
  35. import java.text.ParseException;
  36. import java.util.Collections;
  37. import java.util.Map;
  38. import java.util.Set;

  39. import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
  40. import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
  41. import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
  42. import org.mitre.oauth2.model.ClientDetailsEntity;
  43. import org.mitre.oauth2.model.PKCEAlgorithm;
  44. import org.mitre.oauth2.service.ClientDetailsEntityService;
  45. import org.slf4j.Logger;
  46. import org.slf4j.LoggerFactory;
  47. import org.springframework.beans.factory.annotation.Autowired;
  48. import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
  49. import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
  50. import org.springframework.security.oauth2.common.util.OAuth2Utils;
  51. import org.springframework.security.oauth2.provider.AuthorizationRequest;
  52. import org.springframework.security.oauth2.provider.request.DefaultOAuth2RequestFactory;
  53. import org.springframework.stereotype.Component;

  54. import com.google.common.base.Strings;
  55. import com.google.gson.JsonElement;
  56. import com.google.gson.JsonObject;
  57. import com.google.gson.JsonParser;
  58. import com.nimbusds.jose.Algorithm;
  59. import com.nimbusds.jose.JWEObject.State;
  60. import com.nimbusds.jose.JWSAlgorithm;
  61. import com.nimbusds.jwt.EncryptedJWT;
  62. import com.nimbusds.jwt.JWT;
  63. import com.nimbusds.jwt.JWTClaimsSet;
  64. import com.nimbusds.jwt.JWTParser;
  65. import com.nimbusds.jwt.PlainJWT;
  66. import com.nimbusds.jwt.SignedJWT;

  67. @Component("connectOAuth2RequestFactory")
  68. public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {

  69.     /**
  70.      * Logger for this class
  71.      */
  72.     private static final Logger logger = LoggerFactory.getLogger(ConnectOAuth2RequestFactory.class);

  73.     private ClientDetailsEntityService clientDetailsService;

  74.     @Autowired
  75.     private ClientKeyCacheService validators;

  76.     @Autowired
  77.     private JWTEncryptionAndDecryptionService encryptionService;

  78.     private JsonParser parser = new JsonParser();

  79.     /**
  80.      * Constructor with arguments
  81.      *
  82.      * @param clientDetailsService
  83.      */
  84.     @Autowired
  85.     public ConnectOAuth2RequestFactory(ClientDetailsEntityService clientDetailsService) {
  86.         super(clientDetailsService);
  87.         this.clientDetailsService = clientDetailsService;
  88.     }

  89.     @Override
  90.     public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {


  91.         AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
  92.                 inputParams.get(OAuth2Utils.CLIENT_ID),
  93.                 OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
  94.                 null, false, inputParams.get(OAuth2Utils.STATE),
  95.                 inputParams.get(OAuth2Utils.REDIRECT_URI),
  96.                 OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));

  97.         //Add extension parameters to the 'extensions' map

  98.         if (inputParams.containsKey(PROMPT)) {
  99.             request.getExtensions().put(PROMPT, inputParams.get(PROMPT));
  100.         }
  101.         if (inputParams.containsKey(NONCE)) {
  102.             request.getExtensions().put(NONCE, inputParams.get(NONCE));
  103.         }

  104.         if (inputParams.containsKey(CLAIMS)) {
  105.             JsonObject claimsRequest = parseClaimRequest(inputParams.get(CLAIMS));
  106.             if (claimsRequest != null) {
  107.                 request.getExtensions().put(CLAIMS, claimsRequest.toString());
  108.             }
  109.         }

  110.         if (inputParams.containsKey(MAX_AGE)) {
  111.             request.getExtensions().put(MAX_AGE, inputParams.get(MAX_AGE));
  112.         }

  113.         if (inputParams.containsKey(LOGIN_HINT)) {
  114.             request.getExtensions().put(LOGIN_HINT, inputParams.get(LOGIN_HINT));
  115.         }

  116.         if (inputParams.containsKey(AUD)) {
  117.             request.getExtensions().put(AUD, inputParams.get(AUD));
  118.         }

  119.         if (inputParams.containsKey(CODE_CHALLENGE)) {
  120.             request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE));
  121.             if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) {
  122.                 request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD));
  123.             } else {
  124.                 // if the client doesn't specify a code challenge transformation method, it's "plain"
  125.                 request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName());
  126.             }

  127.         }

  128.         if (inputParams.containsKey(REQUEST)) {
  129.             request.getExtensions().put(REQUEST, inputParams.get(REQUEST));
  130.             processRequestObject(inputParams.get(REQUEST), request);
  131.         }

  132.         if (request.getClientId() != null) {
  133.             try {
  134.                 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());

  135.                 if ((request.getScope() == null || request.getScope().isEmpty())) {
  136.                     Set<String> clientScopes = client.getScope();
  137.                     request.setScope(clientScopes);
  138.                 }

  139.                 if (request.getExtensions().get(MAX_AGE) == null && client.getDefaultMaxAge() != null) {
  140.                     request.getExtensions().put(MAX_AGE, client.getDefaultMaxAge().toString());
  141.                 }
  142.             } catch (OAuth2Exception e) {
  143.                 logger.error("Caught OAuth2 exception trying to test client scopes and max age:", e);
  144.             }
  145.         }

  146.         return request;
  147.     }

  148.     /**
  149.      *
  150.      * @param jwtString
  151.      * @param request
  152.      */
  153.     private void processRequestObject(String jwtString, AuthorizationRequest request) {

  154.         // parse the request object
  155.         try {
  156.             JWT jwt = JWTParser.parse(jwtString);

  157.             if (jwt instanceof SignedJWT) {
  158.                 // it's a signed JWT, check the signature

  159.                 SignedJWT signedJwt = (SignedJWT)jwt;

  160.                 // need to check clientId first so that we can load the client to check other fields
  161.                 if (request.getClientId() == null) {
  162.                     request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
  163.                 }

  164.                 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());

  165.                 if (client == null) {
  166.                     throw new InvalidClientException("Client not found: " + request.getClientId());
  167.                 }


  168.                 JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm();

  169.                 if (client.getRequestObjectSigningAlg() == null ||
  170.                         !client.getRequestObjectSigningAlg().equals(alg)) {
  171.                     throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
  172.                 }

  173.                 JWTSigningAndValidationService validator = validators.getValidator(client, alg);

  174.                 if (validator == null) {
  175.                     throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg);
  176.                 }

  177.                 if (!validator.validateSignature(signedJwt)) {
  178.                     throw new InvalidClientException("Signature did not validate for presented JWT request object.");
  179.                 }

  180.             } else if (jwt instanceof PlainJWT) {
  181.                 PlainJWT plainJwt = (PlainJWT)jwt;

  182.                 // need to check clientId first so that we can load the client to check other fields
  183.                 if (request.getClientId() == null) {
  184.                     request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
  185.                 }

  186.                 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());

  187.                 if (client == null) {
  188.                     throw new InvalidClientException("Client not found: " + request.getClientId());
  189.                 }

  190.                 if (client.getRequestObjectSigningAlg() == null) {
  191.                     throw new InvalidClientException("Client is not registered for unsigned request objects (no request_object_signing_alg registered)");
  192.                 } else if (!client.getRequestObjectSigningAlg().equals(Algorithm.NONE)) {
  193.                     throw new InvalidClientException("Client is not registered for unsigned request objects (request_object_signing_alg is " + client.getRequestObjectSigningAlg() +")");
  194.                 }

  195.                 // if we got here, we're OK, keep processing

  196.             } else if (jwt instanceof EncryptedJWT) {

  197.                 EncryptedJWT encryptedJWT = (EncryptedJWT)jwt;

  198.                 // decrypt the jwt if we can

  199.                 encryptionService.decryptJwt(encryptedJWT);

  200.                 // TODO: what if the content is a signed JWT? (#525)

  201.                 if (!encryptedJWT.getState().equals(State.DECRYPTED)) {
  202.                     throw new InvalidClientException("Unable to decrypt the request object");
  203.                 }

  204.                 // need to check clientId first so that we can load the client to check other fields
  205.                 if (request.getClientId() == null) {
  206.                     request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim(CLIENT_ID));
  207.                 }

  208.                 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());

  209.                 if (client == null) {
  210.                     throw new InvalidClientException("Client not found: " + request.getClientId());
  211.                 }


  212.             }


  213.             /*
  214.              * NOTE: Claims inside the request object always take precedence over those in the parameter map.
  215.              */

  216.             // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims

  217.             JWTClaimsSet claims = jwt.getJWTClaimsSet();

  218.             Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim(RESPONSE_TYPE));
  219.             if (!responseTypes.isEmpty()) {
  220.                 if (!responseTypes.equals(request.getResponseTypes())) {
  221.                     logger.info("Mismatch between request object and regular parameter for response_type, using request object");
  222.                 }
  223.                 request.setResponseTypes(responseTypes);
  224.             }

  225.             String redirectUri = claims.getStringClaim(REDIRECT_URI);
  226.             if (redirectUri != null) {
  227.                 if (!redirectUri.equals(request.getRedirectUri())) {
  228.                     logger.info("Mismatch between request object and regular parameter for redirect_uri, using request object");
  229.                 }
  230.                 request.setRedirectUri(redirectUri);
  231.             }

  232.             String state = claims.getStringClaim(STATE);
  233.             if(state != null) {
  234.                 if (!state.equals(request.getState())) {
  235.                     logger.info("Mismatch between request object and regular parameter for state, using request object");
  236.                 }
  237.                 request.setState(state);
  238.             }

  239.             String nonce = claims.getStringClaim(NONCE);
  240.             if(nonce != null) {
  241.                 if (!nonce.equals(request.getExtensions().get(NONCE))) {
  242.                     logger.info("Mismatch between request object and regular parameter for nonce, using request object");
  243.                 }
  244.                 request.getExtensions().put(NONCE, nonce);
  245.             }

  246.             String display = claims.getStringClaim(DISPLAY);
  247.             if (display != null) {
  248.                 if (!display.equals(request.getExtensions().get(DISPLAY))) {
  249.                     logger.info("Mismatch between request object and regular parameter for display, using request object");
  250.                 }
  251.                 request.getExtensions().put(DISPLAY, display);
  252.             }

  253.             String prompt = claims.getStringClaim(PROMPT);
  254.             if (prompt != null) {
  255.                 if (!prompt.equals(request.getExtensions().get(PROMPT))) {
  256.                     logger.info("Mismatch between request object and regular parameter for prompt, using request object");
  257.                 }
  258.                 request.getExtensions().put(PROMPT, prompt);
  259.             }

  260.             Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim(SCOPE));
  261.             if (!scope.isEmpty()) {
  262.                 if (!scope.equals(request.getScope())) {
  263.                     logger.info("Mismatch between request object and regular parameter for scope, using request object");
  264.                 }
  265.                 request.setScope(scope);
  266.             }

  267.             JsonObject claimRequest = parseClaimRequest(claims.getStringClaim(CLAIMS));
  268.             if (claimRequest != null) {
  269.                 Serializable claimExtension = request.getExtensions().get(CLAIMS);
  270.                 if (claimExtension == null || !claimRequest.equals(parseClaimRequest(claimExtension.toString()))) {
  271.                     logger.info("Mismatch between request object and regular parameter for claims, using request object");
  272.                 }
  273.                 // we save the string because the object might not be a Java Serializable, and we can parse it easily enough anyway
  274.                 request.getExtensions().put(CLAIMS, claimRequest.toString());
  275.             }

  276.             String loginHint = claims.getStringClaim(LOGIN_HINT);
  277.             if (loginHint != null) {
  278.                 if (!loginHint.equals(request.getExtensions().get(LOGIN_HINT))) {
  279.                     logger.info("Mistmatch between request object and regular parameter for login_hint, using requst object");
  280.                 }
  281.                 request.getExtensions().put(LOGIN_HINT, loginHint);
  282.             }

  283.         } catch (ParseException e) {
  284.             logger.error("ParseException while parsing RequestObject:", e);
  285.         }
  286.     }

  287.     /**
  288.      * @param claimRequestString
  289.      * @return
  290.      */
  291.     private JsonObject parseClaimRequest(String claimRequestString) {
  292.         if (Strings.isNullOrEmpty(claimRequestString)) {
  293.             return null;
  294.         } else {
  295.             JsonElement el = parser.parse(claimRequestString);
  296.             if (el != null && el.isJsonObject()) {
  297.                 return el.getAsJsonObject();
  298.             } else {
  299.                 return null;
  300.             }
  301.         }
  302.     }

  303. }