IntrospectingTokenService.java

/*******************************************************************************
 * Copyright 2017 The MIT Internet Trust Consortium
 *
 * Portions copyright 2011-2013 The MITRE Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package org.mitre.oauth2.introspectingfilter;

import static org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod.SECRET_BASIC;

import java.io.IOException;
import java.net.URI;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.mitre.oauth2.introspectingfilter.service.IntrospectionAuthorityGranter;
import org.mitre.oauth2.introspectingfilter.service.IntrospectionConfigurationService;
import org.mitre.oauth2.introspectingfilter.service.impl.SimpleIntrospectionAuthorityGranter;
import org.mitre.oauth2.model.RegisteredClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.nimbusds.jose.util.Base64;

/**
 * This ResourceServerTokenServices implementation introspects incoming tokens at a
 * server's introspection endpoint URL and passes an Authentication object along
 * based on the response from the introspection endpoint.
 * @author jricher
 *
 */
public class IntrospectingTokenService implements ResourceServerTokenServices {

	private IntrospectionConfigurationService introspectionConfigurationService;
	private IntrospectionAuthorityGranter introspectionAuthorityGranter = new SimpleIntrospectionAuthorityGranter();

	private int defaultExpireTime = 300000; // 5 minutes in milliseconds
	private boolean forceCacheExpireTime = false; // force removal of cached tokens based on default expire time
	private boolean cacheNonExpiringTokens = false;
	private boolean cacheTokens = true;

	private HttpComponentsClientHttpRequestFactory factory;

	public IntrospectingTokenService() {
		this(HttpClientBuilder.create().useSystemProperties().build());
	}

	public IntrospectingTokenService(HttpClient httpClient) {
		this.factory = new HttpComponentsClientHttpRequestFactory(httpClient);
	}

	// Inner class to store in the hash map
	private class TokenCacheObject {
		OAuth2AccessToken token;
		OAuth2Authentication auth;
		Date cacheExpire;

		private TokenCacheObject(OAuth2AccessToken token, OAuth2Authentication auth) {
			this.token = token;
			this.auth = auth;

			// we don't need to check the cacheTokens values, because this won't actually be added to the cache if cacheTokens is false
			// if the token isn't null we use the token expire time
			// if forceCacheExpireTime is also true, we also make sure that the token expire time is shorter than the default expire time
			if ((this.token.getExpiration() != null) && (!forceCacheExpireTime || (forceCacheExpireTime && (this.token.getExpiration().getTime() - System.currentTimeMillis() <= defaultExpireTime)))) {
				this.cacheExpire = this.token.getExpiration();
			} else { // if the token doesn't have an expire time, or if the using forceCacheExpireTime the token expire time is longer than the default, then use the default expire time
				Calendar cal = Calendar.getInstance();
				cal.add(Calendar.MILLISECOND, defaultExpireTime);
				this.cacheExpire = cal.getTime();
			}
		}
	}

	private Map<String, TokenCacheObject> authCache = new HashMap<>();
	/**
	 * Logger for this class
	 */
	private static final Logger logger = LoggerFactory.getLogger(IntrospectingTokenService.class);

	/**
	 * @return the introspectionConfigurationService
	 */
	public IntrospectionConfigurationService getIntrospectionConfigurationService() {
		return introspectionConfigurationService;
	}

	/**
	 * @param introspectionConfigurationService the introspectionConfigurationService to set
	 */
	public void setIntrospectionConfigurationService(IntrospectionConfigurationService introspectionUrlProvider) {
		this.introspectionConfigurationService = introspectionUrlProvider;
	}

	/**
	 * @param introspectionAuthorityGranter the introspectionAuthorityGranter to set
	 */
	public void setIntrospectionAuthorityGranter(IntrospectionAuthorityGranter introspectionAuthorityGranter) {
		this.introspectionAuthorityGranter = introspectionAuthorityGranter;
	}

	/**
	 * @return the introspectionAuthorityGranter
	 */
	public IntrospectionAuthorityGranter getIntrospectionAuthorityGranter() {
		return introspectionAuthorityGranter;
	}

	/**
	 * get the default cache expire time in milliseconds
	 * @return
	 */
	public int getDefaultExpireTime() {
		return defaultExpireTime;
	}

	/**
	 * set the default cache expire time in milliseconds
	 * @param defaultExpireTime
	 */
	public void setDefaultExpireTime(int defaultExpireTime) {
		this.defaultExpireTime = defaultExpireTime;
	}

	/**
	 * check if forcing a cache expire time maximum value
	 * @return the forceCacheExpireTime setting
	 */
	public boolean isForceCacheExpireTime() {
		return forceCacheExpireTime;
	}

	/**
	 * set forcing a cache expire time maximum value
	 * @param forceCacheExpireTime
	 */
	public void setForceCacheExpireTime(boolean forceCacheExpireTime) {
		this.forceCacheExpireTime = forceCacheExpireTime;
	}

	/**
	 * Are non-expiring tokens cached using the default cache time
	 * @return state of cacheNonExpiringTokens
	 */
	public boolean isCacheNonExpiringTokens() {
		return cacheNonExpiringTokens;
	}

	/**
	 * should non-expiring tokens be cached using the default cache timeout
	 * @param cacheNonExpiringTokens
	 */
	public void setCacheNonExpiringTokens(boolean cacheNonExpiringTokens) {
		this.cacheNonExpiringTokens = cacheNonExpiringTokens;
	}

	/**
	 * Is the service caching tokens, or is it hitting the introspection end point every time
	 * @return true is caching tokens locally, false hits the introspection end point every time
	 */
	public boolean isCacheTokens() {
		return cacheTokens;
	}

	/**
	 * Configure if the client should cache tokens locally or not
	 * @param cacheTokens
	 */
	public void setCacheTokens(boolean cacheTokens) {
		this.cacheTokens = cacheTokens;
	}

	/**
	 * Check to see if the introspection end point response for a token has been cached locally
	 * This call will return the token if it has been cached and is still valid according to
	 * the cache expire time on the TokenCacheObject. If a cached value has been found but is
	 * expired, either by default expire times or the token's own expire time, then the token is
	 * removed from the cache and null is returned.
	 * @param key is the token to check
	 * @return the cached TokenCacheObject or null
	 */
	private TokenCacheObject checkCache(String key) {
		if (cacheTokens && authCache.containsKey(key)) {
			TokenCacheObject tco = authCache.get(key);

			if (tco != null && tco.cacheExpire != null && tco.cacheExpire.after(new Date())) {
				return tco;
			} else {
				// if the token is expired, don't keep things around.
				authCache.remove(key);
			}
		}
		return null;
	}

	private OAuth2Request createStoredRequest(final JsonObject token) {
		String clientId = token.get("client_id").getAsString();
		Set<String> scopes = new HashSet<>();
		if (token.has("scope")) {
			scopes.addAll(OAuth2Utils.parseParameterList(token.get("scope").getAsString()));
		}
		Map<String, String> parameters = new HashMap<>();
		parameters.put("client_id", clientId);
		parameters.put("scope", OAuth2Utils.formatParameterList(scopes));
		OAuth2Request storedRequest = new OAuth2Request(parameters, clientId, null, true, scopes, null, null, null, null);
		return storedRequest;
	}

	private Authentication createUserAuthentication(JsonObject token) {
		JsonElement userId = token.get("user_id");
		if(userId == null) {
			return null;
		}

		return new PreAuthenticatedAuthenticationToken(userId.getAsString(), token, introspectionAuthorityGranter.getAuthorities(token));
	}

	private OAuth2AccessToken createAccessToken(final JsonObject token, final String tokenString) {
		OAuth2AccessToken accessToken = new OAuth2AccessTokenImpl(token, tokenString);
		return accessToken;
	}

	/**
	 * Validate a token string against the introspection endpoint,
	 * then parse it and store it in the local cache if caching is enabled.
	 *
	 * @param accessToken Token to pass to the introspection endpoint
	 * @return TokenCacheObject containing authentication and token if the token was valid, otherwise null
	 */
	private TokenCacheObject parseToken(String accessToken) {

		// find out which URL to ask
		String introspectionUrl;
		RegisteredClient client;
		try {
			introspectionUrl = introspectionConfigurationService.getIntrospectionUrl(accessToken);
			client = introspectionConfigurationService.getClientConfiguration(accessToken);
		} catch (IllegalArgumentException e) {
			logger.error("Unable to load introspection URL or client configuration", e);
			return null;
		}
		// Use the SpringFramework RestTemplate to send the request to the
		// endpoint
		String validatedToken = null;

		RestTemplate restTemplate;
		MultiValueMap<String, String> form = new LinkedMultiValueMap<>();

		final String clientId = client.getClientId();
		final String clientSecret = client.getClientSecret();

		if (SECRET_BASIC.equals(client.getTokenEndpointAuthMethod())){
			// use BASIC auth if configured to do so
			restTemplate = new RestTemplate(factory) {

				@Override
				protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException {
					ClientHttpRequest httpRequest = super.createRequest(url, method);
					httpRequest.getHeaders().add("Authorization",
							String.format("Basic %s", Base64.encode(String.format("%s:%s", clientId, clientSecret)) ));
					return httpRequest;
				}
			};
		} else {  //Alternatively use form based auth
			restTemplate = new RestTemplate(factory);

			form.add("client_id", clientId);
			form.add("client_secret", clientSecret);
		}

		form.add("token", accessToken);

		try {
			validatedToken = restTemplate.postForObject(introspectionUrl, form, String.class);
		} catch (RestClientException rce) {
			logger.error("validateToken", rce);
			return null;
		}
		if (validatedToken != null) {
			// parse the json
			JsonElement jsonRoot = new JsonParser().parse(validatedToken);
			if (!jsonRoot.isJsonObject()) {
				return null; // didn't get a proper JSON object
			}

			JsonObject tokenResponse = jsonRoot.getAsJsonObject();

			if (tokenResponse.get("error") != null) {
				// report an error?
				logger.error("Got an error back: " + tokenResponse.get("error") + ", " + tokenResponse.get("error_description"));
				return null;
			}

			if (!tokenResponse.get("active").getAsBoolean()) {
				// non-valid token
				logger.info("Server returned non-active token");
				return null;
			}
			// create an OAuth2Authentication
			OAuth2Authentication auth = new OAuth2Authentication(createStoredRequest(tokenResponse), createUserAuthentication(tokenResponse));
			// create an OAuth2AccessToken
			OAuth2AccessToken token = createAccessToken(tokenResponse, accessToken);

			if (token.getExpiration() == null || token.getExpiration().after(new Date())) {
				// Store them in the cache
				TokenCacheObject tco = new TokenCacheObject(token, auth);
				if (cacheTokens && (cacheNonExpiringTokens || token.getExpiration() != null)) {
					authCache.put(accessToken, tco);
				}
				return tco;
			}
		}

		// when the token is invalid for whatever reason
		return null;
	}

	@Override
	public OAuth2Authentication loadAuthentication(String accessToken) throws AuthenticationException {
		// First check if the in memory cache has an Authentication object, and
		// that it is still valid
		// If Valid, return it
		TokenCacheObject cacheAuth = checkCache(accessToken);
		if (cacheAuth != null) {
			return cacheAuth.auth;
		} else {
			cacheAuth = parseToken(accessToken);
			if (cacheAuth != null) {
				return cacheAuth.auth;
			} else {
				return null;
			}
		}
	}

	@Override
	public OAuth2AccessToken readAccessToken(String accessToken) {
		// First check if the in memory cache has a Token object, and that it is
		// still valid
		// If Valid, return it
		TokenCacheObject cacheAuth = checkCache(accessToken);
		if (cacheAuth != null) {
			return cacheAuth.token;
		} else {
			cacheAuth = parseToken(accessToken);
			if (cacheAuth != null) {
				return cacheAuth.token;
			} else {
				return null;
			}
		}
	}

}