001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.oauth2.repository.impl;
019
020import java.text.ParseException;
021import java.util.ArrayList;
022import java.util.Date;
023import java.util.LinkedHashSet;
024import java.util.List;
025import java.util.Set;
026
027import javax.persistence.EntityManager;
028import javax.persistence.PersistenceContext;
029import javax.persistence.Query;
030import javax.persistence.TypedQuery;
031import javax.persistence.criteria.CriteriaBuilder;
032import javax.persistence.criteria.CriteriaDelete;
033import javax.persistence.criteria.Root;
034
035import org.mitre.data.DefaultPageCriteria;
036import org.mitre.data.PageCriteria;
037import org.mitre.oauth2.model.ClientDetailsEntity;
038import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
039import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
040import org.mitre.oauth2.repository.OAuth2TokenRepository;
041import org.mitre.openid.connect.model.ApprovedSite;
042import org.mitre.uma.model.ResourceSet;
043import org.mitre.util.jpa.JpaUtil;
044import org.slf4j.Logger;
045import org.slf4j.LoggerFactory;
046import org.springframework.stereotype.Repository;
047import org.springframework.transaction.annotation.Transactional;
048
049import com.nimbusds.jwt.JWT;
050import com.nimbusds.jwt.JWTParser;
051
052@Repository
053public class JpaOAuth2TokenRepository implements OAuth2TokenRepository {
054
055        private static final int MAXEXPIREDRESULTS = 1000;
056
057        private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class);
058
059        @PersistenceContext(unitName="defaultPersistenceUnit")
060        private EntityManager manager;
061
062        @Override
063        public Set<OAuth2AccessTokenEntity> getAllAccessTokens() {
064                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class);
065                return new LinkedHashSet<>(query.getResultList());
066        }
067
068        @Override
069        public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens() {
070                TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class);
071                return new LinkedHashSet<>(query.getResultList());
072        }
073
074
075        @Override
076        public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) {
077                try {
078                        JWT jwt = JWTParser.parse(accessTokenValue);
079                        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class);
080                        query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt);
081                        return JpaUtil.getSingleResult(query.getResultList());
082                } catch (ParseException e) {
083                        return null;
084                }
085        }
086
087        @Override
088        public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
089                return manager.find(OAuth2AccessTokenEntity.class, id);
090        }
091
092        @Override
093        @Transactional(value="defaultTransactionManager")
094        public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) {
095                return JpaUtil.saveOrUpdate(token.getId(), manager, token);
096        }
097
098        @Override
099        @Transactional(value="defaultTransactionManager")
100        public void removeAccessToken(OAuth2AccessTokenEntity accessToken) {
101                OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId());
102                if (found != null) {
103                        manager.remove(found);
104                } else {
105                        throw new IllegalArgumentException("Access token not found: " + accessToken);
106                }
107        }
108
109        @Override
110        @Transactional(value="defaultTransactionManager")
111        public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
112                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class);
113                query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken);
114                List<OAuth2AccessTokenEntity> accessTokens = query.getResultList();
115                for (OAuth2AccessTokenEntity accessToken : accessTokens) {
116                        removeAccessToken(accessToken);
117                }
118        }
119
120        @Override
121        public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) {
122                try {
123                        JWT jwt = JWTParser.parse(refreshTokenValue);
124                        TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class);
125                        query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt);
126                        return JpaUtil.getSingleResult(query.getResultList());
127                } catch (ParseException e) {
128                        return null;
129                }
130        }
131
132        @Override
133        public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
134                return manager.find(OAuth2RefreshTokenEntity.class, id);
135        }
136
137        @Override
138        @Transactional(value="defaultTransactionManager")
139        public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
140                return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken);
141        }
142
143        @Override
144        @Transactional(value="defaultTransactionManager")
145        public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
146                OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId());
147                if (found != null) {
148                        manager.remove(found);
149                } else {
150                        throw new IllegalArgumentException("Refresh token not found: " + refreshToken);
151                }
152        }
153
154        @Override
155        @Transactional(value="defaultTransactionManager")
156        public void clearTokensForClient(ClientDetailsEntity client) {
157                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class);
158                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client);
159                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
160                for (OAuth2AccessTokenEntity accessToken : accessTokens) {
161                        removeAccessToken(accessToken);
162                }
163                TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class);
164                queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client);
165                List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
166                for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) {
167                        removeRefreshToken(refreshToken);
168                }
169        }
170
171        /* (non-Javadoc)
172         * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
173         */
174        @Override
175        public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
176                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class);
177                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client);
178                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
179                return accessTokens;
180        }
181
182        /* (non-Javadoc)
183         * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getRefreshTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
184         */
185        @Override
186        public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
187                TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class);
188                queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client);
189                List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
190                return refreshTokens;
191        }
192
193        @Override
194        public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() {
195                DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS);
196                return getAllExpiredAccessTokens(pageCriteria);
197        }
198
199        @Override
200        public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(PageCriteria pageCriteria) {
201                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class);
202                query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date());
203                return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria));
204        }
205
206        @Override
207        public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens() {
208                DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS);
209                return getAllExpiredRefreshTokens(pageCriteria);
210        }
211
212        @Override
213        public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pageCriteria) {
214                TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class);
215                query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date());
216                return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria));
217        }
218
219
220
221        /* (non-Javadoc)
222         * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForResourceSet(org.mitre.uma.model.ResourceSet)
223         */
224        @Override
225        public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet rs) {
226                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class);
227                query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId());
228                return new LinkedHashSet<>(query.getResultList());
229        }
230
231        /* (non-Javadoc)
232         * @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateAccessTokens()
233         */
234        @Override
235        @Transactional(value="defaultTransactionManager")
236        public void clearDuplicateAccessTokens() {
237
238                Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1");
239                @SuppressWarnings("unchecked")
240                List<Object[]> resultList = query.getResultList();
241                List<JWT> values = new ArrayList<>();
242                for (Object[] r : resultList) {
243                        logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]);
244                        values.add((JWT) r[0]);
245                }
246                if (values.size() > 0) {
247                        CriteriaBuilder cb = manager.getCriteriaBuilder();
248                        CriteriaDelete<OAuth2AccessTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class);
249                        Root<OAuth2AccessTokenEntity> root = criteriaDelete.from(OAuth2AccessTokenEntity.class);
250                        criteriaDelete.where(root.get("jwt").in(values));
251                        int result = manager.createQuery(criteriaDelete).executeUpdate();
252                        logger.warn("Deleted {} duplicate access tokens", result);
253                }
254        }
255
256        /* (non-Javadoc)
257         * @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateRefreshTokens()
258         */
259        @Override
260        @Transactional(value="defaultTransactionManager")
261        public void clearDuplicateRefreshTokens() {
262                Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1");
263                @SuppressWarnings("unchecked")
264                List<Object[]> resultList = query.getResultList();
265                List<JWT> values = new ArrayList<>();
266                for (Object[] r : resultList) {
267                        logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]);
268                        values.add((JWT) r[0]);
269                }
270                if (values.size() > 0) {
271                        CriteriaBuilder cb = manager.getCriteriaBuilder();
272                        CriteriaDelete<OAuth2RefreshTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class);
273                        Root<OAuth2RefreshTokenEntity> root = criteriaDelete.from(OAuth2RefreshTokenEntity.class);
274                        criteriaDelete.where(root.get("jwt").in(values));
275                        int result = manager.createQuery(criteriaDelete).executeUpdate();
276                        logger.warn("Deleted {} duplicate refresh tokens", result);
277                }
278
279        }
280
281        @Override
282        public List<OAuth2AccessTokenEntity> getAccessTokensForApprovedSite(ApprovedSite approvedSite) {
283                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class);
284                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite);
285                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
286                return accessTokens;
287        }
288
289}