001/******************************************************************************* 002 * Copyright 2017 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018package org.mitre.openid.connect.client; 019 020import java.io.IOException; 021import java.net.URI; 022import java.net.URISyntaxException; 023import java.util.concurrent.ExecutionException; 024import java.util.concurrent.TimeUnit; 025 026import org.apache.http.client.HttpClient; 027import org.apache.http.client.utils.URIBuilder; 028import org.apache.http.impl.client.HttpClientBuilder; 029import org.mitre.openid.connect.config.ServerConfiguration; 030import org.mitre.openid.connect.config.ServerConfiguration.UserInfoTokenMethod; 031import org.mitre.openid.connect.model.DefaultUserInfo; 032import org.mitre.openid.connect.model.PendingOIDCAuthenticationToken; 033import org.mitre.openid.connect.model.UserInfo; 034import org.slf4j.Logger; 035import org.slf4j.LoggerFactory; 036import org.springframework.http.HttpMethod; 037import org.springframework.http.client.ClientHttpRequest; 038import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; 039import org.springframework.util.LinkedMultiValueMap; 040import org.springframework.util.MultiValueMap; 041import org.springframework.web.client.RestTemplate; 042 043import com.google.common.base.Strings; 044import com.google.common.cache.CacheBuilder; 045import com.google.common.cache.CacheLoader; 046import com.google.common.cache.LoadingCache; 047import com.google.common.util.concurrent.UncheckedExecutionException; 048import com.google.gson.JsonObject; 049import com.google.gson.JsonParser; 050 051/** 052 * Utility class to fetch userinfo from the userinfo endpoint, if available. Caches the results. 053 * @author jricher 054 * 055 */ 056public class UserInfoFetcher { 057 058 /** 059 * Logger for this class 060 */ 061 private static final Logger logger = LoggerFactory.getLogger(UserInfoFetcher.class); 062 063 private LoadingCache<PendingOIDCAuthenticationToken, UserInfo> cache; 064 065 public UserInfoFetcher() { 066 this(HttpClientBuilder.create().useSystemProperties().build()); 067 } 068 069 public UserInfoFetcher(HttpClient httpClient) { 070 cache = CacheBuilder.newBuilder() 071 .expireAfterWrite(1, TimeUnit.HOURS) // expires 1 hour after fetch 072 .maximumSize(100) 073 .build(new UserInfoLoader(httpClient)); 074 } 075 076 public UserInfo loadUserInfo(final PendingOIDCAuthenticationToken token) { 077 try { 078 return cache.get(token); 079 } catch (UncheckedExecutionException | ExecutionException e) { 080 logger.warn("Couldn't load User Info from token: " + e.getMessage()); 081 return null; 082 } 083 084 } 085 086 087 private class UserInfoLoader extends CacheLoader<PendingOIDCAuthenticationToken, UserInfo> { 088 private HttpComponentsClientHttpRequestFactory factory; 089 090 UserInfoLoader(HttpClient httpClient) { 091 this.factory = new HttpComponentsClientHttpRequestFactory(httpClient); 092 } 093 094 @Override 095 public UserInfo load(final PendingOIDCAuthenticationToken token) throws URISyntaxException { 096 097 ServerConfiguration serverConfiguration = token.getServerConfiguration(); 098 099 if (serverConfiguration == null) { 100 logger.warn("No server configuration found."); 101 return null; 102 } 103 104 if (Strings.isNullOrEmpty(serverConfiguration.getUserInfoUri())) { 105 logger.warn("No userinfo endpoint, not fetching."); 106 return null; 107 } 108 109 String userInfoString = null; 110 111 if (serverConfiguration.getUserInfoTokenMethod() == null || serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.HEADER)) { 112 RestTemplate restTemplate = new RestTemplate(factory) { 113 114 @Override 115 protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException { 116 ClientHttpRequest httpRequest = super.createRequest(url, method); 117 httpRequest.getHeaders().add("Authorization", String.format("Bearer %s", token.getAccessTokenValue())); 118 return httpRequest; 119 } 120 }; 121 122 userInfoString = restTemplate.getForObject(serverConfiguration.getUserInfoUri(), String.class); 123 124 } else if (serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.FORM)) { 125 MultiValueMap<String, String> form = new LinkedMultiValueMap<>(); 126 form.add("access_token", token.getAccessTokenValue()); 127 128 RestTemplate restTemplate = new RestTemplate(factory); 129 userInfoString = restTemplate.postForObject(serverConfiguration.getUserInfoUri(), form, String.class); 130 } else if (serverConfiguration.getUserInfoTokenMethod().equals(UserInfoTokenMethod.QUERY)) { 131 URIBuilder builder = new URIBuilder(serverConfiguration.getUserInfoUri()); 132 builder.setParameter("access_token", token.getAccessTokenValue()); 133 134 RestTemplate restTemplate = new RestTemplate(factory); 135 userInfoString = restTemplate.getForObject(builder.toString(), String.class); 136 } 137 138 139 if (!Strings.isNullOrEmpty(userInfoString)) { 140 141 JsonObject userInfoJson = new JsonParser().parse(userInfoString).getAsJsonObject(); 142 143 UserInfo userInfo = fromJson(userInfoJson); 144 145 return userInfo; 146 } else { 147 // didn't get anything throw exception 148 throw new IllegalArgumentException("Unable to load user info"); 149 } 150 151 } 152 } 153 154 protected UserInfo fromJson(JsonObject userInfoJson) { 155 return DefaultUserInfo.fromJson(userInfoJson); 156 } 157}