001/******************************************************************************* 002 * Copyright 2017 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018package org.mitre.openid.connect.request; 019 020 021import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD; 022import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS; 023import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID; 024import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE; 025import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD; 026import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY; 027import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT; 028import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE; 029import static org.mitre.openid.connect.request.ConnectRequestParameters.NONCE; 030import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT; 031import static org.mitre.openid.connect.request.ConnectRequestParameters.REDIRECT_URI; 032import static org.mitre.openid.connect.request.ConnectRequestParameters.REQUEST; 033import static org.mitre.openid.connect.request.ConnectRequestParameters.RESPONSE_TYPE; 034import static org.mitre.openid.connect.request.ConnectRequestParameters.SCOPE; 035import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE; 036 037import java.io.Serializable; 038import java.text.ParseException; 039import java.util.Collections; 040import java.util.Map; 041import java.util.Set; 042 043import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; 044import org.mitre.jwt.signer.service.JWTSigningAndValidationService; 045import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; 046import org.mitre.oauth2.model.ClientDetailsEntity; 047import org.mitre.oauth2.model.PKCEAlgorithm; 048import org.mitre.oauth2.service.ClientDetailsEntityService; 049import org.slf4j.Logger; 050import org.slf4j.LoggerFactory; 051import org.springframework.beans.factory.annotation.Autowired; 052import org.springframework.security.oauth2.common.exceptions.InvalidClientException; 053import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; 054import org.springframework.security.oauth2.common.util.OAuth2Utils; 055import org.springframework.security.oauth2.provider.AuthorizationRequest; 056import org.springframework.security.oauth2.provider.request.DefaultOAuth2RequestFactory; 057import org.springframework.stereotype.Component; 058 059import com.google.common.base.Strings; 060import com.google.gson.JsonElement; 061import com.google.gson.JsonObject; 062import com.google.gson.JsonParser; 063import com.nimbusds.jose.Algorithm; 064import com.nimbusds.jose.JWEObject.State; 065import com.nimbusds.jose.JWSAlgorithm; 066import com.nimbusds.jwt.EncryptedJWT; 067import com.nimbusds.jwt.JWT; 068import com.nimbusds.jwt.JWTClaimsSet; 069import com.nimbusds.jwt.JWTParser; 070import com.nimbusds.jwt.PlainJWT; 071import com.nimbusds.jwt.SignedJWT; 072 073@Component("connectOAuth2RequestFactory") 074public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { 075 076 /** 077 * Logger for this class 078 */ 079 private static final Logger logger = LoggerFactory.getLogger(ConnectOAuth2RequestFactory.class); 080 081 private ClientDetailsEntityService clientDetailsService; 082 083 @Autowired 084 private ClientKeyCacheService validators; 085 086 @Autowired 087 private JWTEncryptionAndDecryptionService encryptionService; 088 089 private JsonParser parser = new JsonParser(); 090 091 /** 092 * Constructor with arguments 093 * 094 * @param clientDetailsService 095 */ 096 @Autowired 097 public ConnectOAuth2RequestFactory(ClientDetailsEntityService clientDetailsService) { 098 super(clientDetailsService); 099 this.clientDetailsService = clientDetailsService; 100 } 101 102 @Override 103 public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) { 104 105 106 AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(), 107 inputParams.get(OAuth2Utils.CLIENT_ID), 108 OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null, 109 null, false, inputParams.get(OAuth2Utils.STATE), 110 inputParams.get(OAuth2Utils.REDIRECT_URI), 111 OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE))); 112 113 //Add extension parameters to the 'extensions' map 114 115 if (inputParams.containsKey(PROMPT)) { 116 request.getExtensions().put(PROMPT, inputParams.get(PROMPT)); 117 } 118 if (inputParams.containsKey(NONCE)) { 119 request.getExtensions().put(NONCE, inputParams.get(NONCE)); 120 } 121 122 if (inputParams.containsKey(CLAIMS)) { 123 JsonObject claimsRequest = parseClaimRequest(inputParams.get(CLAIMS)); 124 if (claimsRequest != null) { 125 request.getExtensions().put(CLAIMS, claimsRequest.toString()); 126 } 127 } 128 129 if (inputParams.containsKey(MAX_AGE)) { 130 request.getExtensions().put(MAX_AGE, inputParams.get(MAX_AGE)); 131 } 132 133 if (inputParams.containsKey(LOGIN_HINT)) { 134 request.getExtensions().put(LOGIN_HINT, inputParams.get(LOGIN_HINT)); 135 } 136 137 if (inputParams.containsKey(AUD)) { 138 request.getExtensions().put(AUD, inputParams.get(AUD)); 139 } 140 141 if (inputParams.containsKey(CODE_CHALLENGE)) { 142 request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE)); 143 if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) { 144 request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD)); 145 } else { 146 // if the client doesn't specify a code challenge transformation method, it's "plain" 147 request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName()); 148 } 149 150 } 151 152 if (inputParams.containsKey(REQUEST)) { 153 request.getExtensions().put(REQUEST, inputParams.get(REQUEST)); 154 processRequestObject(inputParams.get(REQUEST), request); 155 } 156 157 if (request.getClientId() != null) { 158 try { 159 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); 160 161 if ((request.getScope() == null || request.getScope().isEmpty())) { 162 Set<String> clientScopes = client.getScope(); 163 request.setScope(clientScopes); 164 } 165 166 if (request.getExtensions().get(MAX_AGE) == null && client.getDefaultMaxAge() != null) { 167 request.getExtensions().put(MAX_AGE, client.getDefaultMaxAge().toString()); 168 } 169 } catch (OAuth2Exception e) { 170 logger.error("Caught OAuth2 exception trying to test client scopes and max age:", e); 171 } 172 } 173 174 return request; 175 } 176 177 /** 178 * 179 * @param jwtString 180 * @param request 181 */ 182 private void processRequestObject(String jwtString, AuthorizationRequest request) { 183 184 // parse the request object 185 try { 186 JWT jwt = JWTParser.parse(jwtString); 187 188 if (jwt instanceof SignedJWT) { 189 // it's a signed JWT, check the signature 190 191 SignedJWT signedJwt = (SignedJWT)jwt; 192 193 // need to check clientId first so that we can load the client to check other fields 194 if (request.getClientId() == null) { 195 request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID)); 196 } 197 198 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); 199 200 if (client == null) { 201 throw new InvalidClientException("Client not found: " + request.getClientId()); 202 } 203 204 205 JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm(); 206 207 if (client.getRequestObjectSigningAlg() == null || 208 !client.getRequestObjectSigningAlg().equals(alg)) { 209 throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")"); 210 } 211 212 JWTSigningAndValidationService validator = validators.getValidator(client, alg); 213 214 if (validator == null) { 215 throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg); 216 } 217 218 if (!validator.validateSignature(signedJwt)) { 219 throw new InvalidClientException("Signature did not validate for presented JWT request object."); 220 } 221 222 } else if (jwt instanceof PlainJWT) { 223 PlainJWT plainJwt = (PlainJWT)jwt; 224 225 // need to check clientId first so that we can load the client to check other fields 226 if (request.getClientId() == null) { 227 request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID)); 228 } 229 230 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); 231 232 if (client == null) { 233 throw new InvalidClientException("Client not found: " + request.getClientId()); 234 } 235 236 if (client.getRequestObjectSigningAlg() == null) { 237 throw new InvalidClientException("Client is not registered for unsigned request objects (no request_object_signing_alg registered)"); 238 } else if (!client.getRequestObjectSigningAlg().equals(Algorithm.NONE)) { 239 throw new InvalidClientException("Client is not registered for unsigned request objects (request_object_signing_alg is " + client.getRequestObjectSigningAlg() +")"); 240 } 241 242 // if we got here, we're OK, keep processing 243 244 } else if (jwt instanceof EncryptedJWT) { 245 246 EncryptedJWT encryptedJWT = (EncryptedJWT)jwt; 247 248 // decrypt the jwt if we can 249 250 encryptionService.decryptJwt(encryptedJWT); 251 252 // TODO: what if the content is a signed JWT? (#525) 253 254 if (!encryptedJWT.getState().equals(State.DECRYPTED)) { 255 throw new InvalidClientException("Unable to decrypt the request object"); 256 } 257 258 // need to check clientId first so that we can load the client to check other fields 259 if (request.getClientId() == null) { 260 request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim(CLIENT_ID)); 261 } 262 263 ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); 264 265 if (client == null) { 266 throw new InvalidClientException("Client not found: " + request.getClientId()); 267 } 268 269 270 } 271 272 273 /* 274 * NOTE: Claims inside the request object always take precedence over those in the parameter map. 275 */ 276 277 // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims 278 279 JWTClaimsSet claims = jwt.getJWTClaimsSet(); 280 281 Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim(RESPONSE_TYPE)); 282 if (!responseTypes.isEmpty()) { 283 if (!responseTypes.equals(request.getResponseTypes())) { 284 logger.info("Mismatch between request object and regular parameter for response_type, using request object"); 285 } 286 request.setResponseTypes(responseTypes); 287 } 288 289 String redirectUri = claims.getStringClaim(REDIRECT_URI); 290 if (redirectUri != null) { 291 if (!redirectUri.equals(request.getRedirectUri())) { 292 logger.info("Mismatch between request object and regular parameter for redirect_uri, using request object"); 293 } 294 request.setRedirectUri(redirectUri); 295 } 296 297 String state = claims.getStringClaim(STATE); 298 if(state != null) { 299 if (!state.equals(request.getState())) { 300 logger.info("Mismatch between request object and regular parameter for state, using request object"); 301 } 302 request.setState(state); 303 } 304 305 String nonce = claims.getStringClaim(NONCE); 306 if(nonce != null) { 307 if (!nonce.equals(request.getExtensions().get(NONCE))) { 308 logger.info("Mismatch between request object and regular parameter for nonce, using request object"); 309 } 310 request.getExtensions().put(NONCE, nonce); 311 } 312 313 String display = claims.getStringClaim(DISPLAY); 314 if (display != null) { 315 if (!display.equals(request.getExtensions().get(DISPLAY))) { 316 logger.info("Mismatch between request object and regular parameter for display, using request object"); 317 } 318 request.getExtensions().put(DISPLAY, display); 319 } 320 321 String prompt = claims.getStringClaim(PROMPT); 322 if (prompt != null) { 323 if (!prompt.equals(request.getExtensions().get(PROMPT))) { 324 logger.info("Mismatch between request object and regular parameter for prompt, using request object"); 325 } 326 request.getExtensions().put(PROMPT, prompt); 327 } 328 329 Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim(SCOPE)); 330 if (!scope.isEmpty()) { 331 if (!scope.equals(request.getScope())) { 332 logger.info("Mismatch between request object and regular parameter for scope, using request object"); 333 } 334 request.setScope(scope); 335 } 336 337 JsonObject claimRequest = parseClaimRequest(claims.getStringClaim(CLAIMS)); 338 if (claimRequest != null) { 339 Serializable claimExtension = request.getExtensions().get(CLAIMS); 340 if (claimExtension == null || !claimRequest.equals(parseClaimRequest(claimExtension.toString()))) { 341 logger.info("Mismatch between request object and regular parameter for claims, using request object"); 342 } 343 // we save the string because the object might not be a Java Serializable, and we can parse it easily enough anyway 344 request.getExtensions().put(CLAIMS, claimRequest.toString()); 345 } 346 347 String loginHint = claims.getStringClaim(LOGIN_HINT); 348 if (loginHint != null) { 349 if (!loginHint.equals(request.getExtensions().get(LOGIN_HINT))) { 350 logger.info("Mistmatch between request object and regular parameter for login_hint, using requst object"); 351 } 352 request.getExtensions().put(LOGIN_HINT, loginHint); 353 } 354 355 } catch (ParseException e) { 356 logger.error("ParseException while parsing RequestObject:", e); 357 } 358 } 359 360 /** 361 * @param claimRequestString 362 * @return 363 */ 364 private JsonObject parseClaimRequest(String claimRequestString) { 365 if (Strings.isNullOrEmpty(claimRequestString)) { 366 return null; 367 } else { 368 JsonElement el = parser.parse(claimRequestString); 369 if (el != null && el.isJsonObject()) { 370 return el.getAsJsonObject(); 371 } else { 372 return null; 373 } 374 } 375 } 376 377}