DefaultOAuth2ProviderTokenService.java

/*******************************************************************************
 * Copyright 2017 The MIT Internet Trust Consortium
 *
 * Portions copyright 2011-2013 The MITRE Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
/**
 *
 */
package org.mitre.oauth2.service.impl;

import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import org.mitre.data.AbstractPageOperationTemplate;
import org.mitre.data.DefaultPageCriteria;
import org.mitre.oauth2.model.AuthenticationHolderEntity;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
import org.mitre.oauth2.model.PKCEAlgorithm;
import org.mitre.oauth2.model.SystemScope;
import org.mitre.oauth2.repository.AuthenticationHolderRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.OAuth2TokenEntityService;
import org.mitre.oauth2.service.SystemScopeService;
import org.mitre.openid.connect.model.ApprovedSite;
import org.mitre.openid.connect.service.ApprovedSiteService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
import org.springframework.security.oauth2.common.exceptions.InvalidScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.TokenRequest;
import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;


/**
 * @author jricher
 *
 */
@Service("defaultOAuth2ProviderTokenService")
public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityService {

	/**
	 * Logger for this class
	 */
	private static final Logger logger = LoggerFactory.getLogger(DefaultOAuth2ProviderTokenService.class);

	@Autowired
	private OAuth2TokenRepository tokenRepository;

	@Autowired
	private AuthenticationHolderRepository authenticationHolderRepository;

	@Autowired
	private ClientDetailsEntityService clientDetailsService;

	@Autowired
	private TokenEnhancer tokenEnhancer;

	@Autowired
	private SystemScopeService scopeService;

	@Autowired
	private ApprovedSiteService approvedSiteService;


	@Override
	public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {

		Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
		Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();

		for (OAuth2AccessTokenEntity token : all) {
			if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
				results.add(token);
			}
		}

		return results;
	}


	@Override
	public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) {
		Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens();
		Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();

		for (OAuth2RefreshTokenEntity token : all) {
			if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
				results.add(token);
			}
		}

		return results;
	}

	@Override
	public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
		return clearExpiredAccessToken(tokenRepository.getAccessTokenById(id));
	}

	@Override
	public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
		return clearExpiredRefreshToken(tokenRepository.getRefreshTokenById(id));
	}

	/**
	 * Utility function to delete an access token that's expired before returning it.
	 * @param token the token to check
	 * @return null if the token is null or expired, the input token (unchanged) if it hasn't
	 */
	private OAuth2AccessTokenEntity clearExpiredAccessToken(OAuth2AccessTokenEntity token) {
		if (token == null) {
			return null;
		} else if (token.isExpired()) {
			// immediately revoke expired token
			logger.debug("Clearing expired access token: " + token.getValue());
			revokeAccessToken(token);
			return null;
		} else {
			return token;
		}
	}

	/**
	 * Utility function to delete a refresh token that's expired before returning it.
	 * @param token the token to check
	 * @return null if the token is null or expired, the input token (unchanged) if it hasn't
	 */
	private OAuth2RefreshTokenEntity clearExpiredRefreshToken(OAuth2RefreshTokenEntity token) {
		if (token == null) {
			return null;
		} else if (token.isExpired()) {
			// immediately revoke expired token
			logger.debug("Clearing expired refresh token: " + token.getValue());
			revokeRefreshToken(token);
			return null;
		} else {
			return token;
		}
	}

	@Override
	@Transactional(value="defaultTransactionManager")
	public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentication) throws AuthenticationException, InvalidClientException {
		if (authentication != null && authentication.getOAuth2Request() != null) {
			// look up our client
			OAuth2Request request = authentication.getOAuth2Request();

			ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());

			if (client == null) {
				throw new InvalidClientException("Client not found: " + request.getClientId());
			}


			// handle the PKCE code challenge if present
			if (request.getExtensions().containsKey(CODE_CHALLENGE)) {
				String challenge = (String) request.getExtensions().get(CODE_CHALLENGE);
				PKCEAlgorithm alg = PKCEAlgorithm.parse((String) request.getExtensions().get(CODE_CHALLENGE_METHOD));

				String verifier = request.getRequestParameters().get(CODE_VERIFIER);

				if (alg.equals(PKCEAlgorithm.plain)) {
					// do a direct string comparison
					if (!challenge.equals(verifier)) {
						throw new InvalidRequestException("Code challenge and verifier do not match");
					}
				} else if (alg.equals(PKCEAlgorithm.S256)) {
					// hash the verifier
					try {
						MessageDigest digest = MessageDigest.getInstance("SHA-256");
						String hash = Base64URL.encode(digest.digest(verifier.getBytes(StandardCharsets.US_ASCII))).toString();
						if (!challenge.equals(hash)) {
							throw new InvalidRequestException("Code challenge and verifier do not match");
						}
					} catch (NoSuchAlgorithmException e) {
						logger.error("Unknown algorithm for PKCE digest", e);
					}
				}

			}


			OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();//accessTokenFactory.createNewAccessToken();

			// attach the client
			token.setClient(client);

			// inherit the scope from the auth, but make a new set so it is
			//not unmodifiable. Unmodifiables don't play nicely with Eclipselink, which
			//wants to use the clone operation.
			Set<SystemScope> scopes = scopeService.fromStrings(request.getScope());

			// remove any of the special system scopes
			scopes = scopeService.removeReservedScopes(scopes);

			token.setScope(scopeService.toStrings(scopes));

			// make it expire if necessary
			if (client.getAccessTokenValiditySeconds() != null && client.getAccessTokenValiditySeconds() > 0) {
				Date expiration = new Date(System.currentTimeMillis() + (client.getAccessTokenValiditySeconds() * 1000L));
				token.setExpiration(expiration);
			}

			// attach the authorization so that we can look it up later
			AuthenticationHolderEntity authHolder = new AuthenticationHolderEntity();
			authHolder.setAuthentication(authentication);
			authHolder = authenticationHolderRepository.save(authHolder);

			token.setAuthenticationHolder(authHolder);

			// attach a refresh token, if this client is allowed to request them and the user gets the offline scope
			if (client.isAllowRefresh() && token.getScope().contains(SystemScopeService.OFFLINE_ACCESS)) {
				OAuth2RefreshTokenEntity savedRefreshToken = createRefreshToken(client, authHolder);

				token.setRefreshToken(savedRefreshToken);
			}

			//Add approved site reference, if any
			OAuth2Request originalAuthRequest = authHolder.getAuthentication().getOAuth2Request();

			if (originalAuthRequest.getExtensions() != null && originalAuthRequest.getExtensions().containsKey("approved_site")) {

				Long apId = Long.parseLong((String) originalAuthRequest.getExtensions().get("approved_site"));
				ApprovedSite ap = approvedSiteService.getById(apId);

				token.setApprovedSite(ap);
			}

			OAuth2AccessTokenEntity enhancedToken = (OAuth2AccessTokenEntity) tokenEnhancer.enhance(token, authentication);

			OAuth2AccessTokenEntity savedToken = saveAccessToken(enhancedToken);

			if (savedToken.getRefreshToken() != null) {
				tokenRepository.saveRefreshToken(savedToken.getRefreshToken()); // make sure we save any changes that might have been enhanced
			}

			return savedToken;
		}

		throw new AuthenticationCredentialsNotFoundException("No authentication credentials found");
	}


	private OAuth2RefreshTokenEntity createRefreshToken(ClientDetailsEntity client, AuthenticationHolderEntity authHolder) {
		OAuth2RefreshTokenEntity refreshToken = new OAuth2RefreshTokenEntity(); //refreshTokenFactory.createNewRefreshToken();
		JWTClaimsSet.Builder refreshClaims = new JWTClaimsSet.Builder();


		// make it expire if necessary
		if (client.getRefreshTokenValiditySeconds() != null) {
			Date expiration = new Date(System.currentTimeMillis() + (client.getRefreshTokenValiditySeconds() * 1000L));
			refreshToken.setExpiration(expiration);
			refreshClaims.expirationTime(expiration);
		}

		// set a random identifier
		refreshClaims.jwtID(UUID.randomUUID().toString());

		// TODO: add issuer fields, signature to JWT

		PlainJWT refreshJwt = new PlainJWT(refreshClaims.build());
		refreshToken.setJwt(refreshJwt);

		//Add the authentication
		refreshToken.setAuthenticationHolder(authHolder);
		refreshToken.setClient(client);



		// save the token first so that we can set it to a member of the access token (NOTE: is this step necessary?)
		OAuth2RefreshTokenEntity savedRefreshToken = tokenRepository.saveRefreshToken(refreshToken);
		return savedRefreshToken;
	}

	@Override
	@Transactional(value="defaultTransactionManager")
	public OAuth2AccessTokenEntity refreshAccessToken(String refreshTokenValue, TokenRequest authRequest) throws AuthenticationException {
		
		if (Strings.isNullOrEmpty(refreshTokenValue)) {
			// throw an invalid token exception if there's no refresh token value at all
			throw new InvalidTokenException("Invalid refresh token: " + refreshTokenValue);
		}

		OAuth2RefreshTokenEntity refreshToken = clearExpiredRefreshToken(tokenRepository.getRefreshTokenByValue(refreshTokenValue));

		if (refreshToken == null) {
			// throw an invalid token exception if we couldn't find the token
			throw new InvalidTokenException("Invalid refresh token: " + refreshTokenValue);
		}

		ClientDetailsEntity client = refreshToken.getClient();

		AuthenticationHolderEntity authHolder = refreshToken.getAuthenticationHolder();

		// make sure that the client requesting the token is the one who owns the refresh token
		ClientDetailsEntity requestingClient = clientDetailsService.loadClientByClientId(authRequest.getClientId());
		if (!client.getClientId().equals(requestingClient.getClientId())) {
			tokenRepository.removeRefreshToken(refreshToken);
			throw new InvalidClientException("Client does not own the presented refresh token");
		}

		//Make sure this client allows access token refreshing
		if (!client.isAllowRefresh()) {
			throw new InvalidClientException("Client does not allow refreshing access token!");
		}

		// clear out any access tokens
		if (client.isClearAccessTokensOnRefresh()) {
			tokenRepository.clearAccessTokensForRefreshToken(refreshToken);
		}

		if (refreshToken.isExpired()) {
			tokenRepository.removeRefreshToken(refreshToken);
			throw new InvalidTokenException("Expired refresh token: " + refreshTokenValue);
		}

		OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();

		// get the stored scopes from the authentication holder's authorization request; these are the scopes associated with the refresh token
		Set<String> refreshScopesRequested = new HashSet<>(refreshToken.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope());
		Set<SystemScope> refreshScopes = scopeService.fromStrings(refreshScopesRequested);
		// remove any of the special system scopes
		refreshScopes = scopeService.removeReservedScopes(refreshScopes);

		Set<String> scopeRequested = authRequest.getScope() == null ? new HashSet<String>() : new HashSet<>(authRequest.getScope());
		Set<SystemScope> scope = scopeService.fromStrings(scopeRequested);

		// remove any of the special system scopes
		scope = scopeService.removeReservedScopes(scope);

		if (scope != null && !scope.isEmpty()) {
			// ensure a proper subset of scopes
			if (refreshScopes != null && refreshScopes.containsAll(scope)) {
				// set the scope of the new access token if requested
				token.setScope(scopeService.toStrings(scope));
			} else {
				String errorMsg = "Up-scoping is not allowed.";
				logger.error(errorMsg);
				throw new InvalidScopeException(errorMsg);
			}
		} else {
			// otherwise inherit the scope of the refresh token (if it's there -- this can return a null scope set)
			token.setScope(scopeService.toStrings(refreshScopes));
		}

		token.setClient(client);

		if (client.getAccessTokenValiditySeconds() != null) {
			Date expiration = new Date(System.currentTimeMillis() + (client.getAccessTokenValiditySeconds() * 1000L));
			token.setExpiration(expiration);
		}

		if (client.isReuseRefreshToken()) {
			// if the client re-uses refresh tokens, do that
			token.setRefreshToken(refreshToken);
		} else {
			// otherwise, make a new refresh token
			OAuth2RefreshTokenEntity newRefresh = createRefreshToken(client, authHolder);
			token.setRefreshToken(newRefresh);

			// clean up the old refresh token
			tokenRepository.removeRefreshToken(refreshToken);
		}

		token.setAuthenticationHolder(authHolder);

		tokenEnhancer.enhance(token, authHolder.getAuthentication());

		tokenRepository.saveAccessToken(token);

		return token;

	}

	@Override
	public OAuth2Authentication loadAuthentication(String accessTokenValue) throws AuthenticationException {

		OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));

		if (accessToken == null) {
			throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
		} else {
			return accessToken.getAuthenticationHolder().getAuthentication();
		}
	}


	/**
	 * Get an access token from its token value.
	 */
	@Override
	public OAuth2AccessTokenEntity readAccessToken(String accessTokenValue) throws AuthenticationException {
		OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
		if (accessToken == null) {
			throw new InvalidTokenException("Access token for value " + accessTokenValue + " was not found");
		} else {
			return accessToken;
		}
	}

	/**
	 * Get an access token by its authentication object.
	 */
	@Override
	public OAuth2AccessTokenEntity getAccessToken(OAuth2Authentication authentication) {
		// TODO: implement this against the new service (#825)
		throw new UnsupportedOperationException("Unable to look up access token from authentication object.");
	}

	/**
	 * Get a refresh token by its token value.
	 */
	@Override
	public OAuth2RefreshTokenEntity getRefreshToken(String refreshTokenValue) throws AuthenticationException {
		OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenByValue(refreshTokenValue);
		if (refreshToken == null) {
			throw new InvalidTokenException("Refresh token for value " + refreshTokenValue + " was not found");
		}
		else {
			return refreshToken;
		}
	}

	/**
	 * Revoke a refresh token and all access tokens issued to it.
	 */
	@Override
	@Transactional(value="defaultTransactionManager")
	public void revokeRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
		tokenRepository.clearAccessTokensForRefreshToken(refreshToken);
		tokenRepository.removeRefreshToken(refreshToken);
	}

	/**
	 * Revoke an access token.
	 */
	@Override
	@Transactional(value="defaultTransactionManager")
	public void revokeAccessToken(OAuth2AccessTokenEntity accessToken) {
		tokenRepository.removeAccessToken(accessToken);
	}


	/* (non-Javadoc)
	 * @see org.mitre.oauth2.service.OAuth2TokenEntityService#getAccessTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
	 */
	@Override
	public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
		return tokenRepository.getAccessTokensForClient(client);
	}

	/* (non-Javadoc)
	 * @see org.mitre.oauth2.service.OAuth2TokenEntityService#getRefreshTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
	 */
	@Override
	public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
		return tokenRepository.getRefreshTokensForClient(client);
	}

	/**
	 * Clears out expired tokens and any abandoned authentication objects
	 */
	@Override
	public void clearExpiredTokens() {
		logger.debug("Cleaning out all expired tokens");

		new AbstractPageOperationTemplate<OAuth2AccessTokenEntity>("clearExpiredAccessTokens") {
			@Override
			public Collection<OAuth2AccessTokenEntity> fetchPage() {
				return tokenRepository.getAllExpiredAccessTokens(new DefaultPageCriteria());
			}

			@Override
			public void doOperation(OAuth2AccessTokenEntity item) {
				revokeAccessToken(item);
			}
		}.execute();

		new AbstractPageOperationTemplate<OAuth2RefreshTokenEntity>("clearExpiredRefreshTokens") {
			@Override
			public Collection<OAuth2RefreshTokenEntity> fetchPage() {
				return tokenRepository.getAllExpiredRefreshTokens(new DefaultPageCriteria());
			}

			@Override
			public void doOperation(OAuth2RefreshTokenEntity item) {
				revokeRefreshToken(item);
			}
		}.execute();

		new AbstractPageOperationTemplate<AuthenticationHolderEntity>("clearExpiredAuthenticationHolders") {
			@Override
			public Collection<AuthenticationHolderEntity> fetchPage() {
				return authenticationHolderRepository.getOrphanedAuthenticationHolders(new DefaultPageCriteria());
			}

			@Override
			public void doOperation(AuthenticationHolderEntity item) {
				authenticationHolderRepository.remove(item);
			}
		}.execute();
	}

	/* (non-Javadoc)
	 * @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveAccessToken(org.mitre.oauth2.model.OAuth2AccessTokenEntity)
	 */
	@Override
	@Transactional(value="defaultTransactionManager")
	public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity accessToken) {
		OAuth2AccessTokenEntity newToken = tokenRepository.saveAccessToken(accessToken);

		// if the old token has any additional information for the return from the token endpoint, carry it through here after save
		if (accessToken.getAdditionalInformation() != null && !accessToken.getAdditionalInformation().isEmpty()) {
			newToken.getAdditionalInformation().putAll(accessToken.getAdditionalInformation());
		}

		return newToken;
	}

	/* (non-Javadoc)
	 * @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveRefreshToken(org.mitre.oauth2.model.OAuth2RefreshTokenEntity)
	 */
	@Override
	@Transactional(value="defaultTransactionManager")
	public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
		return tokenRepository.saveRefreshToken(refreshToken);
	}

	/**
	 * @return the tokenEnhancer
	 */
	public TokenEnhancer getTokenEnhancer() {
		return tokenEnhancer;
	}

	/**
	 * @param tokenEnhancer the tokenEnhancer to set
	 */
	public void setTokenEnhancer(TokenEnhancer tokenEnhancer) {
		this.tokenEnhancer = tokenEnhancer;
	}

	@Override
	public OAuth2AccessTokenEntity getRegistrationAccessTokenForClient(ClientDetailsEntity client) {
		List<OAuth2AccessTokenEntity> allTokens = getAccessTokensForClient(client);

		for (OAuth2AccessTokenEntity token : allTokens) {
			if ((token.getScope().contains(SystemScopeService.REGISTRATION_TOKEN_SCOPE) || token.getScope().contains(SystemScopeService.RESOURCE_TOKEN_SCOPE))
					&& token.getScope().size() == 1) {
				// if it only has the registration scope, then it's a registration token
				return token;
			}
		}

		return null;
	}



}