001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.oauth2.introspectingfilter;
019
import static org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod.SECRET_BASIC;

import java.io.IOException;
import java.net.URI;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.mitre.oauth2.introspectingfilter.service.IntrospectionAuthorityGranter;
import org.mitre.oauth2.introspectingfilter.service.IntrospectionConfigurationService;
import org.mitre.oauth2.introspectingfilter.service.impl.SimpleIntrospectionAuthorityGranter;
import org.mitre.oauth2.model.RegisteredClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import com.nimbusds.jose.util.Base64;
059
060/**
061 * This ResourceServerTokenServices implementation introspects incoming tokens at a
062 * server's introspection endpoint URL and passes an Authentication object along
063 * based on the response from the introspection endpoint.
064 * @author jricher
065 *
066 */
067public class IntrospectingTokenService implements ResourceServerTokenServices {
068
069        private IntrospectionConfigurationService introspectionConfigurationService;
070        private IntrospectionAuthorityGranter introspectionAuthorityGranter = new SimpleIntrospectionAuthorityGranter();
071
072        private int defaultExpireTime = 300000; // 5 minutes in milliseconds
073        private boolean forceCacheExpireTime = false; // force removal of cached tokens based on default expire time
074        private boolean cacheNonExpiringTokens = false;
075        private boolean cacheTokens = true;
076
077        private HttpComponentsClientHttpRequestFactory factory;
078
079        public IntrospectingTokenService() {
080                this(HttpClientBuilder.create().useSystemProperties().build());
081        }
082
083        public IntrospectingTokenService(HttpClient httpClient) {
084                this.factory = new HttpComponentsClientHttpRequestFactory(httpClient);
085        }
086
087        // Inner class to store in the hash map
088        private class TokenCacheObject {
089                OAuth2AccessToken token;
090                OAuth2Authentication auth;
091                Date cacheExpire;
092
093                private TokenCacheObject(OAuth2AccessToken token, OAuth2Authentication auth) {
094                        this.token = token;
095                        this.auth = auth;
096
097                        // we don't need to check the cacheTokens values, because this won't actually be added to the cache if cacheTokens is false
098                        // if the token isn't null we use the token expire time
099                        // if forceCacheExpireTime is also true, we also make sure that the token expire time is shorter than the default expire time
100                        if ((this.token.getExpiration() != null) && (!forceCacheExpireTime || (forceCacheExpireTime && (this.token.getExpiration().getTime() - System.currentTimeMillis() <= defaultExpireTime)))) {
101                                this.cacheExpire = this.token.getExpiration();
102                        } else { // if the token doesn't have an expire time, or if the using forceCacheExpireTime the token expire time is longer than the default, then use the default expire time
103                                Calendar cal = Calendar.getInstance();
104                                cal.add(Calendar.MILLISECOND, defaultExpireTime);
105                                this.cacheExpire = cal.getTime();
106                        }
107                }
108        }
109
110        private Map<String, TokenCacheObject> authCache = new HashMap<>();
111        /**
112         * Logger for this class
113         */
114        private static final Logger logger = LoggerFactory.getLogger(IntrospectingTokenService.class);
115
116        /**
117         * @return the introspectionConfigurationService
118         */
119        public IntrospectionConfigurationService getIntrospectionConfigurationService() {
120                return introspectionConfigurationService;
121        }
122
123        /**
124         * @param introspectionConfigurationService the introspectionConfigurationService to set
125         */
126        public void setIntrospectionConfigurationService(IntrospectionConfigurationService introspectionUrlProvider) {
127                this.introspectionConfigurationService = introspectionUrlProvider;
128        }
129
130        /**
131         * @param introspectionAuthorityGranter the introspectionAuthorityGranter to set
132         */
133        public void setIntrospectionAuthorityGranter(IntrospectionAuthorityGranter introspectionAuthorityGranter) {
134                this.introspectionAuthorityGranter = introspectionAuthorityGranter;
135        }
136
137        /**
138         * @return the introspectionAuthorityGranter
139         */
140        public IntrospectionAuthorityGranter getIntrospectionAuthorityGranter() {
141                return introspectionAuthorityGranter;
142        }
143
144        /**
145         * get the default cache expire time in milliseconds
146         * @return
147         */
148        public int getDefaultExpireTime() {
149                return defaultExpireTime;
150        }
151
152        /**
153         * set the default cache expire time in milliseconds
154         * @param defaultExpireTime
155         */
156        public void setDefaultExpireTime(int defaultExpireTime) {
157                this.defaultExpireTime = defaultExpireTime;
158        }
159
160        /**
161         * check if forcing a cache expire time maximum value
162         * @return the forceCacheExpireTime setting
163         */
164        public boolean isForceCacheExpireTime() {
165                return forceCacheExpireTime;
166        }
167
168        /**
169         * set forcing a cache expire time maximum value
170         * @param forceCacheExpireTime
171         */
172        public void setForceCacheExpireTime(boolean forceCacheExpireTime) {
173                this.forceCacheExpireTime = forceCacheExpireTime;
174        }
175
176        /**
177         * Are non-expiring tokens cached using the default cache time
178         * @return state of cacheNonExpiringTokens
179         */
180        public boolean isCacheNonExpiringTokens() {
181                return cacheNonExpiringTokens;
182        }
183
184        /**
185         * should non-expiring tokens be cached using the default cache timeout
186         * @param cacheNonExpiringTokens
187         */
188        public void setCacheNonExpiringTokens(boolean cacheNonExpiringTokens) {
189                this.cacheNonExpiringTokens = cacheNonExpiringTokens;
190        }
191
192        /**
193         * Is the service caching tokens, or is it hitting the introspection end point every time
194         * @return true is caching tokens locally, false hits the introspection end point every time
195         */
196        public boolean isCacheTokens() {
197                return cacheTokens;
198        }
199
200        /**
201         * Configure if the client should cache tokens locally or not
202         * @param cacheTokens
203         */
204        public void setCacheTokens(boolean cacheTokens) {
205                this.cacheTokens = cacheTokens;
206        }
207
208        /**
209         * Check to see if the introspection end point response for a token has been cached locally
210         * This call will return the token if it has been cached and is still valid according to
211         * the cache expire time on the TokenCacheObject. If a cached value has been found but is
212         * expired, either by default expire times or the token's own expire time, then the token is
213         * removed from the cache and null is returned.
214         * @param key is the token to check
215         * @return the cached TokenCacheObject or null
216         */
217        private TokenCacheObject checkCache(String key) {
218                if (cacheTokens && authCache.containsKey(key)) {
219                        TokenCacheObject tco = authCache.get(key);
220
221                        if (tco != null && tco.cacheExpire != null && tco.cacheExpire.after(new Date())) {
222                                return tco;
223                        } else {
224                                // if the token is expired, don't keep things around.
225                                authCache.remove(key);
226                        }
227                }
228                return null;
229        }
230
231        private OAuth2Request createStoredRequest(final JsonObject token) {
232                String clientId = token.get("client_id").getAsString();
233                Set<String> scopes = new HashSet<>();
234                if (token.has("scope")) {
235                        scopes.addAll(OAuth2Utils.parseParameterList(token.get("scope").getAsString()));
236                }
237                Map<String, String> parameters = new HashMap<>();
238                parameters.put("client_id", clientId);
239                parameters.put("scope", OAuth2Utils.formatParameterList(scopes));
240                OAuth2Request storedRequest = new OAuth2Request(parameters, clientId, null, true, scopes, null, null, null, null);
241                return storedRequest;
242        }
243
244        private Authentication createUserAuthentication(JsonObject token) {
245                JsonElement userId = token.get("user_id");
246                if(userId == null) {
247                        return null;
248                }
249
250                return new PreAuthenticatedAuthenticationToken(userId.getAsString(), token, introspectionAuthorityGranter.getAuthorities(token));
251        }
252
253        private OAuth2AccessToken createAccessToken(final JsonObject token, final String tokenString) {
254                OAuth2AccessToken accessToken = new OAuth2AccessTokenImpl(token, tokenString);
255                return accessToken;
256        }
257
258        /**
259         * Validate a token string against the introspection endpoint,
260         * then parse it and store it in the local cache if caching is enabled.
261         *
262         * @param accessToken Token to pass to the introspection endpoint
263         * @return TokenCacheObject containing authentication and token if the token was valid, otherwise null
264         */
265        private TokenCacheObject parseToken(String accessToken) {
266
267                // find out which URL to ask
268                String introspectionUrl;
269                RegisteredClient client;
270                try {
271                        introspectionUrl = introspectionConfigurationService.getIntrospectionUrl(accessToken);
272                        client = introspectionConfigurationService.getClientConfiguration(accessToken);
273                } catch (IllegalArgumentException e) {
274                        logger.error("Unable to load introspection URL or client configuration", e);
275                        return null;
276                }
277                // Use the SpringFramework RestTemplate to send the request to the
278                // endpoint
279                String validatedToken = null;
280
281                RestTemplate restTemplate;
282                MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
283
284                final String clientId = client.getClientId();
285                final String clientSecret = client.getClientSecret();
286
287                if (SECRET_BASIC.equals(client.getTokenEndpointAuthMethod())){
288                        // use BASIC auth if configured to do so
289                        restTemplate = new RestTemplate(factory) {
290
291                                @Override
292                                protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException {
293                                        ClientHttpRequest httpRequest = super.createRequest(url, method);
294                                        httpRequest.getHeaders().add("Authorization",
295                                                        String.format("Basic %s", Base64.encode(String.format("%s:%s", clientId, clientSecret)) ));
296                                        return httpRequest;
297                                }
298                        };
299                } else {  //Alternatively use form based auth
300                        restTemplate = new RestTemplate(factory);
301
302                        form.add("client_id", clientId);
303                        form.add("client_secret", clientSecret);
304                }
305
306                form.add("token", accessToken);
307
308                try {
309                        validatedToken = restTemplate.postForObject(introspectionUrl, form, String.class);
310                } catch (RestClientException rce) {
311                        logger.error("validateToken", rce);
312                        return null;
313                }
314                if (validatedToken != null) {
315                        // parse the json
316                        JsonElement jsonRoot = new JsonParser().parse(validatedToken);
317                        if (!jsonRoot.isJsonObject()) {
318                                return null; // didn't get a proper JSON object
319                        }
320
321                        JsonObject tokenResponse = jsonRoot.getAsJsonObject();
322
323                        if (tokenResponse.get("error") != null) {
324                                // report an error?
325                                logger.error("Got an error back: " + tokenResponse.get("error") + ", " + tokenResponse.get("error_description"));
326                                return null;
327                        }
328
329                        if (!tokenResponse.get("active").getAsBoolean()) {
330                                // non-valid token
331                                logger.info("Server returned non-active token");
332                                return null;
333                        }
334                        // create an OAuth2Authentication
335                        OAuth2Authentication auth = new OAuth2Authentication(createStoredRequest(tokenResponse), createUserAuthentication(tokenResponse));
336                        // create an OAuth2AccessToken
337                        OAuth2AccessToken token = createAccessToken(tokenResponse, accessToken);
338
339                        if (token.getExpiration() == null || token.getExpiration().after(new Date())) {
340                                // Store them in the cache
341                                TokenCacheObject tco = new TokenCacheObject(token, auth);
342                                if (cacheTokens && (cacheNonExpiringTokens || token.getExpiration() != null)) {
343                                        authCache.put(accessToken, tco);
344                                }
345                                return tco;
346                        }
347                }
348
349                // when the token is invalid for whatever reason
350                return null;
351        }
352
353        @Override
354        public OAuth2Authentication loadAuthentication(String accessToken) throws AuthenticationException {
355                // First check if the in memory cache has an Authentication object, and
356                // that it is still valid
357                // If Valid, return it
358                TokenCacheObject cacheAuth = checkCache(accessToken);
359                if (cacheAuth != null) {
360                        return cacheAuth.auth;
361                } else {
362                        cacheAuth = parseToken(accessToken);
363                        if (cacheAuth != null) {
364                                return cacheAuth.auth;
365                        } else {
366                                return null;
367                        }
368                }
369        }
370
371        @Override
372        public OAuth2AccessToken readAccessToken(String accessToken) {
373                // First check if the in memory cache has a Token object, and that it is
374                // still valid
375                // If Valid, return it
376                TokenCacheObject cacheAuth = checkCache(accessToken);
377                if (cacheAuth != null) {
378                        return cacheAuth.token;
379                } else {
380                        cacheAuth = parseToken(accessToken);
381                        if (cacheAuth != null) {
382                                return cacheAuth.token;
383                        } else {
384                                return null;
385                        }
386                }
387        }
388
389}