001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.openid.connect.request;
019
020
import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY;
import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT;
import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.NONCE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT;
import static org.mitre.openid.connect.request.ConnectRequestParameters.REDIRECT_URI;
import static org.mitre.openid.connect.request.ConnectRequestParameters.REQUEST;
import static org.mitre.openid.connect.request.ConnectRequestParameters.RESPONSE_TYPE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.SCOPE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE;

import java.io.Serializable;
import java.text.ParseException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.PKCEAlgorithm;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.request.DefaultOAuth2RequestFactory;
import org.springframework.stereotype.Component;

import com.google.common.base.Strings;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JWEObject.State;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
072
073@Component("connectOAuth2RequestFactory")
074public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
075
076        /**
077         * Logger for this class
078         */
079        private static final Logger logger = LoggerFactory.getLogger(ConnectOAuth2RequestFactory.class);
080
081        private ClientDetailsEntityService clientDetailsService;
082
083        @Autowired
084        private ClientKeyCacheService validators;
085
086        @Autowired
087        private JWTEncryptionAndDecryptionService encryptionService;
088
089        private JsonParser parser = new JsonParser();
090
091        /**
092         * Constructor with arguments
093         *
094         * @param clientDetailsService
095         */
096        @Autowired
097        public ConnectOAuth2RequestFactory(ClientDetailsEntityService clientDetailsService) {
098                super(clientDetailsService);
099                this.clientDetailsService = clientDetailsService;
100        }
101
102        @Override
103        public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {
104
105
106                AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
107                                inputParams.get(OAuth2Utils.CLIENT_ID),
108                                OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
109                                null, false, inputParams.get(OAuth2Utils.STATE),
110                                inputParams.get(OAuth2Utils.REDIRECT_URI),
111                                OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));
112
113                //Add extension parameters to the 'extensions' map
114
115                if (inputParams.containsKey(PROMPT)) {
116                        request.getExtensions().put(PROMPT, inputParams.get(PROMPT));
117                }
118                if (inputParams.containsKey(NONCE)) {
119                        request.getExtensions().put(NONCE, inputParams.get(NONCE));
120                }
121
122                if (inputParams.containsKey(CLAIMS)) {
123                        JsonObject claimsRequest = parseClaimRequest(inputParams.get(CLAIMS));
124                        if (claimsRequest != null) {
125                                request.getExtensions().put(CLAIMS, claimsRequest.toString());
126                        }
127                }
128
129                if (inputParams.containsKey(MAX_AGE)) {
130                        request.getExtensions().put(MAX_AGE, inputParams.get(MAX_AGE));
131                }
132
133                if (inputParams.containsKey(LOGIN_HINT)) {
134                        request.getExtensions().put(LOGIN_HINT, inputParams.get(LOGIN_HINT));
135                }
136
137                if (inputParams.containsKey(AUD)) {
138                        request.getExtensions().put(AUD, inputParams.get(AUD));
139                }
140
141                if (inputParams.containsKey(CODE_CHALLENGE)) {
142                        request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE));
143                        if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) {
144                                request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD));
145                        } else {
146                                // if the client doesn't specify a code challenge transformation method, it's "plain"
147                                request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName());
148                        }
149
150                }
151
152                if (inputParams.containsKey(REQUEST)) {
153                        request.getExtensions().put(REQUEST, inputParams.get(REQUEST));
154                        processRequestObject(inputParams.get(REQUEST), request);
155                }
156
157                if (request.getClientId() != null) {
158                        try {
159                                ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
160
161                                if ((request.getScope() == null || request.getScope().isEmpty())) {
162                                        Set<String> clientScopes = client.getScope();
163                                        request.setScope(clientScopes);
164                                }
165
166                                if (request.getExtensions().get(MAX_AGE) == null && client.getDefaultMaxAge() != null) {
167                                        request.getExtensions().put(MAX_AGE, client.getDefaultMaxAge().toString());
168                                }
169                        } catch (OAuth2Exception e) {
170                                logger.error("Caught OAuth2 exception trying to test client scopes and max age:", e);
171                        }
172                }
173
174                return request;
175        }
176
177        /**
178         *
179         * @param jwtString
180         * @param request
181         */
182        private void processRequestObject(String jwtString, AuthorizationRequest request) {
183
184                // parse the request object
185                try {
186                        JWT jwt = JWTParser.parse(jwtString);
187
188                        if (jwt instanceof SignedJWT) {
189                                // it's a signed JWT, check the signature
190
191                                SignedJWT signedJwt = (SignedJWT)jwt;
192
193                                // need to check clientId first so that we can load the client to check other fields
194                                if (request.getClientId() == null) {
195                                        request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
196                                }
197
198                                ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
199
200                                if (client == null) {
201                                        throw new InvalidClientException("Client not found: " + request.getClientId());
202                                }
203
204
205                                JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm();
206
207                                if (client.getRequestObjectSigningAlg() == null ||
208                                                !client.getRequestObjectSigningAlg().equals(alg)) {
209                                        throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
210                                }
211
212                                JWTSigningAndValidationService validator = validators.getValidator(client, alg);
213
214                                if (validator == null) {
215                                        throw new InvalidClientException("Unable to create signature validator for client " + client + " and algorithm " + alg);
216                                }
217
218                                if (!validator.validateSignature(signedJwt)) {
219                                        throw new InvalidClientException("Signature did not validate for presented JWT request object.");
220                                }
221
222                        } else if (jwt instanceof PlainJWT) {
223                                PlainJWT plainJwt = (PlainJWT)jwt;
224
225                                // need to check clientId first so that we can load the client to check other fields
226                                if (request.getClientId() == null) {
227                                        request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim(CLIENT_ID));
228                                }
229
230                                ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
231
232                                if (client == null) {
233                                        throw new InvalidClientException("Client not found: " + request.getClientId());
234                                }
235
236                                if (client.getRequestObjectSigningAlg() == null) {
237                                        throw new InvalidClientException("Client is not registered for unsigned request objects (no request_object_signing_alg registered)");
238                                } else if (!client.getRequestObjectSigningAlg().equals(Algorithm.NONE)) {
239                                        throw new InvalidClientException("Client is not registered for unsigned request objects (request_object_signing_alg is " + client.getRequestObjectSigningAlg() +")");
240                                }
241
242                                // if we got here, we're OK, keep processing
243
244                        } else if (jwt instanceof EncryptedJWT) {
245
246                                EncryptedJWT encryptedJWT = (EncryptedJWT)jwt;
247
248                                // decrypt the jwt if we can
249
250                                encryptionService.decryptJwt(encryptedJWT);
251
252                                // TODO: what if the content is a signed JWT? (#525)
253
254                                if (!encryptedJWT.getState().equals(State.DECRYPTED)) {
255                                        throw new InvalidClientException("Unable to decrypt the request object");
256                                }
257
258                                // need to check clientId first so that we can load the client to check other fields
259                                if (request.getClientId() == null) {
260                                        request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim(CLIENT_ID));
261                                }
262
263                                ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
264
265                                if (client == null) {
266                                        throw new InvalidClientException("Client not found: " + request.getClientId());
267                                }
268
269
270                        }
271
272
273                        /*
274                         * NOTE: Claims inside the request object always take precedence over those in the parameter map.
275                         */
276
277                        // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
278
279                        JWTClaimsSet claims = jwt.getJWTClaimsSet();
280
281                        Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim(RESPONSE_TYPE));
282                        if (!responseTypes.isEmpty()) {
283                                if (!responseTypes.equals(request.getResponseTypes())) {
284                                        logger.info("Mismatch between request object and regular parameter for response_type, using request object");
285                                }
286                                request.setResponseTypes(responseTypes);
287                        }
288
289                        String redirectUri = claims.getStringClaim(REDIRECT_URI);
290                        if (redirectUri != null) {
291                                if (!redirectUri.equals(request.getRedirectUri())) {
292                                        logger.info("Mismatch between request object and regular parameter for redirect_uri, using request object");
293                                }
294                                request.setRedirectUri(redirectUri);
295                        }
296
297                        String state = claims.getStringClaim(STATE);
298                        if(state != null) {
299                                if (!state.equals(request.getState())) {
300                                        logger.info("Mismatch between request object and regular parameter for state, using request object");
301                                }
302                                request.setState(state);
303                        }
304
305                        String nonce = claims.getStringClaim(NONCE);
306                        if(nonce != null) {
307                                if (!nonce.equals(request.getExtensions().get(NONCE))) {
308                                        logger.info("Mismatch between request object and regular parameter for nonce, using request object");
309                                }
310                                request.getExtensions().put(NONCE, nonce);
311                        }
312
313                        String display = claims.getStringClaim(DISPLAY);
314                        if (display != null) {
315                                if (!display.equals(request.getExtensions().get(DISPLAY))) {
316                                        logger.info("Mismatch between request object and regular parameter for display, using request object");
317                                }
318                                request.getExtensions().put(DISPLAY, display);
319                        }
320
321                        String prompt = claims.getStringClaim(PROMPT);
322                        if (prompt != null) {
323                                if (!prompt.equals(request.getExtensions().get(PROMPT))) {
324                                        logger.info("Mismatch between request object and regular parameter for prompt, using request object");
325                                }
326                                request.getExtensions().put(PROMPT, prompt);
327                        }
328
329                        Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim(SCOPE));
330                        if (!scope.isEmpty()) {
331                                if (!scope.equals(request.getScope())) {
332                                        logger.info("Mismatch between request object and regular parameter for scope, using request object");
333                                }
334                                request.setScope(scope);
335                        }
336
337                        JsonObject claimRequest = parseClaimRequest(claims.getStringClaim(CLAIMS));
338                        if (claimRequest != null) {
339                                Serializable claimExtension = request.getExtensions().get(CLAIMS);
340                                if (claimExtension == null || !claimRequest.equals(parseClaimRequest(claimExtension.toString()))) {
341                                        logger.info("Mismatch between request object and regular parameter for claims, using request object");
342                                }
343                                // we save the string because the object might not be a Java Serializable, and we can parse it easily enough anyway
344                                request.getExtensions().put(CLAIMS, claimRequest.toString());
345                        }
346
347                        String loginHint = claims.getStringClaim(LOGIN_HINT);
348                        if (loginHint != null) {
349                                if (!loginHint.equals(request.getExtensions().get(LOGIN_HINT))) {
350                                        logger.info("Mistmatch between request object and regular parameter for login_hint, using requst object");
351                                }
352                                request.getExtensions().put(LOGIN_HINT, loginHint);
353                        }
354
355                } catch (ParseException e) {
356                        logger.error("ParseException while parsing RequestObject:", e);
357                }
358        }
359
360        /**
361         * @param claimRequestString
362         * @return
363         */
364        private JsonObject parseClaimRequest(String claimRequestString) {
365                if (Strings.isNullOrEmpty(claimRequestString)) {
366                        return null;
367                } else {
368                        JsonElement el = parser.parse(claimRequestString);
369                        if (el != null && el.isJsonObject()) {
370                                return el.getAsJsonObject();
371                        } else {
372                                return null;
373                        }
374                }
375        }
376
377}