001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.openid.connect.client;
019
020import java.io.IOException;
021import java.net.URI;
022import java.net.URISyntaxException;
023import java.util.concurrent.ExecutionException;
024import java.util.concurrent.TimeUnit;
025
026import org.apache.http.client.HttpClient;
027import org.apache.http.client.utils.URIBuilder;
028import org.apache.http.impl.client.HttpClientBuilder;
029import org.mitre.openid.connect.config.ServerConfiguration;
030import org.mitre.openid.connect.config.ServerConfiguration.UserInfoTokenMethod;
031import org.mitre.openid.connect.model.DefaultUserInfo;
032import org.mitre.openid.connect.model.PendingOIDCAuthenticationToken;
033import org.mitre.openid.connect.model.UserInfo;
034import org.slf4j.Logger;
035import org.slf4j.LoggerFactory;
036import org.springframework.http.HttpMethod;
037import org.springframework.http.client.ClientHttpRequest;
038import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
039import org.springframework.util.LinkedMultiValueMap;
040import org.springframework.util.MultiValueMap;
041import org.springframework.web.client.RestTemplate;
042
043import com.google.common.base.Strings;
044import com.google.common.cache.CacheBuilder;
045import com.google.common.cache.CacheLoader;
046import com.google.common.cache.LoadingCache;
047import com.google.common.util.concurrent.UncheckedExecutionException;
048import com.google.gson.JsonObject;
049import com.google.gson.JsonParser;
050
051/**
052 * Utility class to fetch userinfo from the userinfo endpoint, if available. Caches the results.
053 * @author jricher
054 *
055 */
056public class UserInfoFetcher {
057
058        /**
059         * Logger for this class
060         */
061        private static final Logger logger = LoggerFactory.getLogger(UserInfoFetcher.class);
062
063        private LoadingCache<PendingOIDCAuthenticationToken, UserInfo> cache;
064
065        public UserInfoFetcher() {
066                this(HttpClientBuilder.create().useSystemProperties().build());
067        }
068
069        public UserInfoFetcher(HttpClient httpClient) {
070                cache = CacheBuilder.newBuilder()
071                                .expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch
072                                .maximumSize(100)
073                                .build(new UserInfoLoader(httpClient));
074        }
075
076        public UserInfo loadUserInfo(final PendingOIDCAuthenticationToken token) {
077                try {
078                        return cache.get(token);
079                } catch (UncheckedExecutionException | ExecutionException e) {
080                        logger.warn("Couldn't load User Info from token: " + e.getMessage());
081                        return null;
082                }
083
084        }
085
086
087        private class UserInfoLoader extends CacheLoader<PendingOIDCAuthenticationToken, UserInfo> {
088                private HttpComponentsClientHttpRequestFactory factory;
089
090                UserInfoLoader(HttpClient httpClient) {
091                        this.factory = new HttpComponentsClientHttpRequestFactory(httpClient);
092                }
093
094                @Override
095                public UserInfo load(final PendingOIDCAuthenticationToken token) throws URISyntaxException {
096
097                        ServerConfiguration serverConfiguration = token.getServerConfiguration();
098
099                        if (serverConfiguration == null) {
100                                logger.warn("No server configuration found.");
101                                return null;
102                        }
103
104                        if (Strings.isNullOrEmpty(serverConfiguration.getUserInfoUri())) {
105                                logger.warn("No userinfo endpoint, not fetching.");
106                                return null;
107                        }
108
109                        String userInfoString = null;
110
111                        if (serverConfiguration.getUserInfoTokenMethod() == null || serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.HEADER)) {
112                                RestTemplate restTemplate = new RestTemplate(factory) {
113
114                                        @Override
115                                        protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException {
116                                                ClientHttpRequest httpRequest = super.createRequest(url, method);
117                                                httpRequest.getHeaders().add("Authorization", String.format("Bearer %s", token.getAccessTokenValue()));
118                                                return httpRequest;
119                                        }
120                                };
121
122                                userInfoString = restTemplate.getForObject(serverConfiguration.getUserInfoUri(), String.class);
123
124                        } else if (serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.FORM)) {
125                                MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
126                                form.add("access_token", token.getAccessTokenValue());
127
128                                RestTemplate restTemplate = new RestTemplate(factory);
129                                userInfoString = restTemplate.postForObject(serverConfiguration.getUserInfoUri(), form, String.class);
130                        } else if (serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.QUERY)) {
131                                URIBuilder builder = new URIBuilder(serverConfiguration.getUserInfoUri());
132                                builder.setParameter("access_token",  token.getAccessTokenValue());
133
134                                RestTemplate restTemplate = new RestTemplate(factory);
135                                userInfoString = restTemplate.getForObject(builder.toString(), String.class);
136                        }
137
138
139                        if (!Strings.isNullOrEmpty(userInfoString)) {
140
141                                JsonObject userInfoJson = new JsonParser().parse(userInfoString).getAsJsonObject();
142
143                                UserInfo userInfo = fromJson(userInfoJson);
144
145                                return userInfo;
146                        } else {
147                                // didn't get anything throw exception
148                                throw new IllegalArgumentException("Unable to load user info");
149                        }
150
151                }
152        }
153
154        protected UserInfo fromJson(JsonObject userInfoJson) {
155                return DefaultUserInfo.fromJson(userInfoJson);
156        }
157}