ConnectOAuth2RequestFactory.java
- /*******************************************************************************
- * Copyright 2017 The MIT Internet Trust Consortium
- *
- * Portions copyright 2011-2013 The MITRE Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
- package org.mitre.openid.connect.request;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.NONCE;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.REDIRECT_URI;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.REQUEST;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.RESPONSE_TYPE;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.SCOPE;
- import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE;
- import java.io.Serializable;
- import java.text.ParseException;
- import java.util.Collections;
- import java.util.Map;
- import java.util.Set;
- import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
- import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
- import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
- import org.mitre.oauth2.model.ClientDetailsEntity;
- import org.mitre.oauth2.model.PKCEAlgorithm;
- import org.mitre.oauth2.service.ClientDetailsEntityService;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
- import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
- import org.springframework.security.oauth2.common.util.OAuth2Utils;
- import org.springframework.security.oauth2.provider.AuthorizationRequest;
- import org.springframework.security.oauth2.provider.request.DefaultOAuth2RequestFactory;
- import org.springframework.stereotype.Component;
- import com.google.common.base.Strings;
- import com.google.gson.JsonElement;
- import com.google.gson.JsonObject;
- import com.google.gson.JsonParser;
- import com.nimbusds.jose.Algorithm;
- import com.nimbusds.jose.JWEObject.State;
- import com.nimbusds.jose.JWSAlgorithm;
- import com.nimbusds.jwt.EncryptedJWT;
- import com.nimbusds.jwt.JWT;
- import com.nimbusds.jwt.JWTClaimsSet;
- import com.nimbusds.jwt.JWTParser;
- import com.nimbusds.jwt.PlainJWT;
- import com.nimbusds.jwt.SignedJWT;
- @Component("connectOAuth2RequestFactory")
- public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
- /**
- * Logger for this class
- */
- private static final Logger logger = LoggerFactory.getLogger(ConnectOAuth2RequestFactory.class);
- private ClientDetailsEntityService clientDetailsService;
- @Autowired
- private ClientKeyCacheService validators;
- @Autowired
- private JWTEncryptionAndDecryptionService encryptionService;
- private JsonParser parser = new JsonParser();
- /**
- * Constructor with arguments
- *
- * @param clientDetailsService
- */
- @Autowired
- public ConnectOAuth2RequestFactory(ClientDetailsEntityService clientDetailsService) {
- super(clientDetailsService);
- this.clientDetailsService = clientDetailsService;
- }
- @Override
- public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {
- AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
- inputParams.get(OAuth2Utils.CLIENT_ID),
- OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
- null, false, inputParams.get(OAuth2Utils.STATE),
- inputParams.get(OAuth2Utils.REDIRECT_URI),
- OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));
- //Add extension parameters to the 'extensions' map
- if (inputParams.containsKey(PROMPT)) {
- request.getExtensions().put(PROMPT, inputParams.get(PROMPT));
- }
- if (inputParams.containsKey(NONCE)) {
- request.getExtensions().put(NONCE, inputParams.get(NONCE));
- }
- if (inputParams.containsKey(CLAIMS)) {
- JsonObject claimsRequest = parseClaimRequest(inputParams.get(CLAIMS));
- if (claimsRequest != null) {
- request.getExtensions().put(CLAIMS, claimsRequest.toString());
- }
- }
- if (inputParams.containsKey(MAX_AGE)) {
- request.getExtensions().put(MAX_AGE, inputParams.get(MAX_AGE));
- }
- if (inputParams.containsKey(LOGIN_HINT)) {
- request.getExtensions().put(LOGIN_HINT, inputParams.get(LOGIN_HINT));
- }
- if (inputParams.containsKey(AUD)) {
- request.getExtensions().put(AUD, inputParams.get(AUD));
- }
- if (inputParams.containsKey(CODE_CHALLENGE)) {
- request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE));
- if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) {
- request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD));
- } else {
- // if the client doesn't specify a code challenge transformation method, it's "plain"
- request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName());
- }
- }
- if (inputParams.containsKey(REQUEST)) {
- request.getExtensions().put(REQUEST, inputParams.get(REQUEST));
- processRequestObject(inputParams.get(REQUEST), request);
- }
- if (request.getClientId() != null) {
- try {
- ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
- if ((request.getScope() == null || request.getScope().isEmpty())) {
- Set<String> clientScopes = client.getScope();
- request.setScope(clientScopes);
- }
- if (request.getExtensions().get(MAX_AGE) == null && client.getDefaultMaxAge() != null) {
- request.getExtensions().put(MAX_AGE, client.getDefaultMaxAge().toString());
- }
- } catch (OAuth2Exception e) {
- logger.error("Caught OAuth2 exception trying to test client scopes and max age:", e);
- }
- }
- return request;
- }
- /**
- *
- * @param jwtString
- * @param request
- */
- private void processRequestObject(String jwtString, AuthorizationRequest request) {
- // parse the request object
- try {
- JWT jwt = JWTParser.parse(jwtString);
- if (jwt instanceof SignedJWT) {
- // it's a signed JWT, check the signature
- SignedJWT signedJwt = (SignedJWT)jwt;
- // need to check clientId first so that we can load the client to check other fields
- if (request.getClientId() == null) {
- request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
- }
- ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
- if (client == null) {
- throw new InvalidClientException("Client not found: " + request.getClientId());
- }
- JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm();
- if (client.getRequestObjectSigningAlg() == null ||
- !client.getRequestObjectSigningAlg().equals(alg)) {
- throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
- }
- JWTSigningAndValidationService validator = validators.getValidator(client, alg);
- if (validator == null) {
- throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg);
- }
- if (!validator.validateSignature(signedJwt)) {
- throw new InvalidClientException("Signature did not validate for presented JWT request object.");
- }
- } else if (jwt instanceof PlainJWT) {
- PlainJWT plainJwt = (PlainJWT)jwt;
- // need to check clientId first so that we can load the client to check other fields
- if (request.getClientId() == null) {
- request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
- }
- ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
- if (client == null) {
- throw new InvalidClientException("Client not found: " + request.getClientId());
- }
- if (client.getRequestObjectSigningAlg() == null) {
- throw new InvalidClientException("Client is not registered for unsigned request objects (no request_object_signing_alg registered)");
- } else if (!client.getRequestObjectSigningAlg().equals(Algorithm.NONE)) {
- throw new InvalidClientException("Client is not registered for unsigned request objects (request_object_signing_alg is " + client.getRequestObjectSigningAlg() +")");
- }
- // if we got here, we're OK, keep processing
- } else if (jwt instanceof EncryptedJWT) {
- EncryptedJWT encryptedJWT = (EncryptedJWT)jwt;
- // decrypt the jwt if we can
- encryptionService.decryptJwt(encryptedJWT);
- // TODO: what if the content is a signed JWT? (#525)
- if (!encryptedJWT.getState().equals(State.DECRYPTED)) {
- throw new InvalidClientException("Unable to decrypt the request object");
- }
- // need to check clientId first so that we can load the client to check other fields
- if (request.getClientId() == null) {
- request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim(CLIENT_ID));
- }
- ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
- if (client == null) {
- throw new InvalidClientException("Client not found: " + request.getClientId());
- }
- }
- /*
- * NOTE: Claims inside the request object always take precedence over those in the parameter map.
- */
- // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
- JWTClaimsSet claims = jwt.getJWTClaimsSet();
- Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim(RESPONSE_TYPE));
- if (!responseTypes.isEmpty()) {
- if (!responseTypes.equals(request.getResponseTypes())) {
- logger.info("Mismatch between request object and regular parameter for response_type, using request object");
- }
- request.setResponseTypes(responseTypes);
- }
- String redirectUri = claims.getStringClaim(REDIRECT_URI);
- if (redirectUri != null) {
- if (!redirectUri.equals(request.getRedirectUri())) {
- logger.info("Mismatch between request object and regular parameter for redirect_uri, using request object");
- }
- request.setRedirectUri(redirectUri);
- }
- String state = claims.getStringClaim(STATE);
- if(state != null) {
- if (!state.equals(request.getState())) {
- logger.info("Mismatch between request object and regular parameter for state, using request object");
- }
- request.setState(state);
- }
- String nonce = claims.getStringClaim(NONCE);
- if(nonce != null) {
- if (!nonce.equals(request.getExtensions().get(NONCE))) {
- logger.info("Mismatch between request object and regular parameter for nonce, using request object");
- }
- request.getExtensions().put(NONCE, nonce);
- }
- String display = claims.getStringClaim(DISPLAY);
- if (display != null) {
- if (!display.equals(request.getExtensions().get(DISPLAY))) {
- logger.info("Mismatch between request object and regular parameter for display, using request object");
- }
- request.getExtensions().put(DISPLAY, display);
- }
- String prompt = claims.getStringClaim(PROMPT);
- if (prompt != null) {
- if (!prompt.equals(request.getExtensions().get(PROMPT))) {
- logger.info("Mismatch between request object and regular parameter for prompt, using request object");
- }
- request.getExtensions().put(PROMPT, prompt);
- }
- Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim(SCOPE));
- if (!scope.isEmpty()) {
- if (!scope.equals(request.getScope())) {
- logger.info("Mismatch between request object and regular parameter for scope, using request object");
- }
- request.setScope(scope);
- }
- JsonObject claimRequest = parseClaimRequest(claims.getStringClaim(CLAIMS));
- if (claimRequest != null) {
- Serializable claimExtension = request.getExtensions().get(CLAIMS);
- if (claimExtension == null || !claimRequest.equals(parseClaimRequest(claimExtension.toString()))) {
- logger.info("Mismatch between request object and regular parameter for claims, using request object");
- }
- // we save the string because the object might not be a Java Serializable, and we can parse it easily enough anyway
- request.getExtensions().put(CLAIMS, claimRequest.toString());
- }
- String loginHint = claims.getStringClaim(LOGIN_HINT);
- if (loginHint != null) {
- if (!loginHint.equals(request.getExtensions().get(LOGIN_HINT))) {
- logger.info("Mistmatch between request object and regular parameter for login_hint, using requst object");
- }
- request.getExtensions().put(LOGIN_HINT, loginHint);
- }
- } catch (ParseException e) {
- logger.error("ParseException while parsing RequestObject:", e);
- }
- }
- /**
- * @param claimRequestString
- * @return
- */
- private JsonObject parseClaimRequest(String claimRequestString) {
- if (Strings.isNullOrEmpty(claimRequestString)) {
- return null;
- } else {
- JsonElement el = parser.parse(claimRequestString);
- if (el != null && el.isJsonObject()) {
- return el.getAsJsonObject();
- } else {
- return null;
- }
- }
- }
- }