001/******************************************************************************* 002 * Copyright 2017 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018package org.mitre.oauth2.repository.impl; 019 020import java.text.ParseException; 021import java.util.ArrayList; 022import java.util.Date; 023import java.util.LinkedHashSet; 024import java.util.List; 025import java.util.Set; 026 027import javax.persistence.EntityManager; 028import javax.persistence.PersistenceContext; 029import javax.persistence.Query; 030import javax.persistence.TypedQuery; 031import javax.persistence.criteria.CriteriaBuilder; 032import javax.persistence.criteria.CriteriaDelete; 033import javax.persistence.criteria.Root; 034 035import org.mitre.data.DefaultPageCriteria; 036import org.mitre.data.PageCriteria; 037import org.mitre.oauth2.model.ClientDetailsEntity; 038import org.mitre.oauth2.model.OAuth2AccessTokenEntity; 039import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; 040import org.mitre.oauth2.repository.OAuth2TokenRepository; 041import org.mitre.openid.connect.model.ApprovedSite; 042import org.mitre.uma.model.ResourceSet; 043import org.mitre.util.jpa.JpaUtil; 044import org.slf4j.Logger; 045import org.slf4j.LoggerFactory; 046import org.springframework.stereotype.Repository; 047import org.springframework.transaction.annotation.Transactional; 048 049import com.nimbusds.jwt.JWT; 050import com.nimbusds.jwt.JWTParser; 051 052@Repository 053public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { 054 055 private static final int MAXEXPIREDRESULTS = 1000; 056 057 private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); 058 059 @PersistenceContext(unitName="defaultPersistenceUnit") 060 private EntityManager manager; 061 062 @Override 063 public Set<OAuth2AccessTokenEntity> getAllAccessTokens() { 064 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class); 065 return new LinkedHashSet<>(query.getResultList()); 066 } 067 068 @Override 069 public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens() { 070 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class); 071 return new LinkedHashSet<>(query.getResultList()); 072 } 073 074 075 @Override 076 public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) { 077 try { 078 JWT jwt = JWTParser.parse(accessTokenValue); 079 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class); 080 query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt); 081 return JpaUtil.getSingleResult(query.getResultList()); 082 } catch (ParseException e) { 083 return null; 084 } 085 } 086 087 @Override 088 public OAuth2AccessTokenEntity getAccessTokenById(Long id) { 089 return manager.find(OAuth2AccessTokenEntity.class, id); 090 } 091 092 @Override 093 @Transactional(value="defaultTransactionManager") 094 public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) { 095 return JpaUtil.saveOrUpdate(token.getId(), manager, token); 096 } 097 098 @Override 099 @Transactional(value="defaultTransactionManager") 100 public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { 101 OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId()); 102 if (found != null) { 103 manager.remove(found); 104 } else { 105 throw new IllegalArgumentException("Access token not found: " + accessToken); 106 } 107 } 108 109 @Override 110 @Transactional(value="defaultTransactionManager") 111 public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 112 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class); 113 query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken); 114 List<OAuth2AccessTokenEntity> accessTokens = query.getResultList(); 115 for (OAuth2AccessTokenEntity accessToken : accessTokens) { 116 removeAccessToken(accessToken); 117 } 118 } 119 120 @Override 121 public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) { 122 try { 123 JWT jwt = JWTParser.parse(refreshTokenValue); 124 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class); 125 query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt); 126 return JpaUtil.getSingleResult(query.getResultList()); 127 } catch (ParseException e) { 128 return null; 129 } 130 } 131 132 @Override 133 public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) { 134 return manager.find(OAuth2RefreshTokenEntity.class, id); 135 } 136 137 @Override 138 @Transactional(value="defaultTransactionManager") 139 public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 140 return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken); 141 } 142 143 @Override 144 @Transactional(value="defaultTransactionManager") 145 public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 146 OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId()); 147 if (found != null) { 148 manager.remove(found); 149 } else { 150 throw new IllegalArgumentException("Refresh token not found: " + refreshToken); 151 } 152 } 153 154 @Override 155 @Transactional(value="defaultTransactionManager") 156 public void clearTokensForClient(ClientDetailsEntity client) { 157 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); 158 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); 159 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 160 for (OAuth2AccessTokenEntity accessToken : accessTokens) { 161 removeAccessToken(accessToken); 162 } 163 TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); 164 queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); 165 List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList(); 166 for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) { 167 removeRefreshToken(refreshToken); 168 } 169 } 170 171 /* (non-Javadoc) 172 * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity) 173 */ 174 @Override 175 public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) { 176 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); 177 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); 178 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 179 return accessTokens; 180 } 181 182 /* (non-Javadoc) 183 * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getRefreshTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity) 184 */ 185 @Override 186 public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) { 187 TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); 188 queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); 189 List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList(); 190 return refreshTokens; 191 } 192 193 @Override 194 public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() { 195 DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); 196 return getAllExpiredAccessTokens(pageCriteria); 197 } 198 199 @Override 200 public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(PageCriteria pageCriteria) { 201 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class); 202 query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); 203 return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); 204 } 205 206 @Override 207 public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens() { 208 DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); 209 return getAllExpiredRefreshTokens(pageCriteria); 210 } 211 212 @Override 213 public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pageCriteria) { 214 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class); 215 query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); 216 return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria)); 217 } 218 219 220 221 /* (non-Javadoc) 222 * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForResourceSet(org.mitre.uma.model.ResourceSet) 223 */ 224 @Override 225 public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet rs) { 226 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class); 227 query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId()); 228 return new LinkedHashSet<>(query.getResultList()); 229 } 230 231 /* (non-Javadoc) 232 * @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateAccessTokens() 233 */ 234 @Override 235 @Transactional(value="defaultTransactionManager") 236 public void clearDuplicateAccessTokens() { 237 238 Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); 239 @SuppressWarnings("unchecked") 240 List<Object[]> resultList = query.getResultList(); 241 List<JWT> values = new ArrayList<>(); 242 for (Object[] r : resultList) { 243 logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); 244 values.add((JWT) r[0]); 245 } 246 if (values.size() > 0) { 247 CriteriaBuilder cb = manager.getCriteriaBuilder(); 248 CriteriaDelete<OAuth2AccessTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); 249 Root<OAuth2AccessTokenEntity> root = criteriaDelete.from(OAuth2AccessTokenEntity.class); 250 criteriaDelete.where(root.get("jwt").in(values)); 251 int result = manager.createQuery(criteriaDelete).executeUpdate(); 252 logger.warn("Deleted {} duplicate access tokens", result); 253 } 254 } 255 256 /* (non-Javadoc) 257 * @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateRefreshTokens() 258 */ 259 @Override 260 @Transactional(value="defaultTransactionManager") 261 public void clearDuplicateRefreshTokens() { 262 Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); 263 @SuppressWarnings("unchecked") 264 List<Object[]> resultList = query.getResultList(); 265 List<JWT> values = new ArrayList<>(); 266 for (Object[] r : resultList) { 267 logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); 268 values.add((JWT) r[0]); 269 } 270 if (values.size() > 0) { 271 CriteriaBuilder cb = manager.getCriteriaBuilder(); 272 CriteriaDelete<OAuth2RefreshTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); 273 Root<OAuth2RefreshTokenEntity> root = criteriaDelete.from(OAuth2RefreshTokenEntity.class); 274 criteriaDelete.where(root.get("jwt").in(values)); 275 int result = manager.createQuery(criteriaDelete).executeUpdate(); 276 logger.warn("Deleted {} duplicate refresh tokens", result); 277 } 278 279 } 280 281 @Override 282 public List<OAuth2AccessTokenEntity> getAccessTokensForApprovedSite(ApprovedSite approvedSite) { 283 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class); 284 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite); 285 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 286 return accessTokens; 287 } 288 289}