001/******************************************************************************* 002 * Copyright 2017 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018/** 019 * 020 */ 021package org.mitre.openid.connect.client.service.impl; 022 023import java.util.HashSet; 024import java.util.Set; 025import java.util.concurrent.ExecutionException; 026 027import org.apache.http.client.HttpClient; 028import org.apache.http.impl.client.HttpClientBuilder; 029import org.mitre.oauth2.model.RegisteredClient; 030import org.mitre.openid.connect.ClientDetailsEntityJsonProcessor; 031import org.mitre.openid.connect.client.service.ClientConfigurationService; 032import org.mitre.openid.connect.client.service.RegisteredClientService; 033import org.mitre.openid.connect.config.ServerConfiguration; 034import org.slf4j.Logger; 035import org.slf4j.LoggerFactory; 036import org.springframework.http.HttpEntity; 037import org.springframework.http.HttpHeaders; 038import org.springframework.http.HttpMethod; 039import org.springframework.http.MediaType; 040import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; 041import org.springframework.security.authentication.AuthenticationServiceException; 042import org.springframework.security.oauth2.common.OAuth2AccessToken; 043import org.springframework.security.oauth2.common.exceptions.InvalidClientException; 044import org.springframework.web.client.RestClientException; 045import org.springframework.web.client.RestTemplate; 046 047import com.google.common.cache.CacheBuilder; 048import com.google.common.cache.CacheLoader; 049import com.google.common.cache.LoadingCache; 050import com.google.common.collect.Lists; 051import com.google.common.util.concurrent.UncheckedExecutionException; 052import com.google.gson.Gson; 053import com.google.gson.JsonObject; 054 055/** 056 * @author jricher 057 * 058 */ 059public class DynamicRegistrationClientConfigurationService implements ClientConfigurationService { 060 061 /** 062 * Logger for this class 063 */ 064 private static final Logger logger = LoggerFactory.getLogger(DynamicRegistrationClientConfigurationService.class); 065 066 private LoadingCache<ServerConfiguration, RegisteredClient> clients; 067 068 private RegisteredClientService registeredClientService = new InMemoryRegisteredClientService(); 069 070 private RegisteredClient template; 071 072 private Set<String> whitelist = new HashSet<>(); 073 private Set<String> blacklist = new HashSet<>(); 074 075 public DynamicRegistrationClientConfigurationService() { 076 this(HttpClientBuilder.create().useSystemProperties().build()); 077 } 078 079 public DynamicRegistrationClientConfigurationService(HttpClient httpClient) { 080 clients = CacheBuilder.newBuilder().build(new DynamicClientRegistrationLoader(httpClient)); 081 } 082 083 @Override 084 public RegisteredClient getClientConfiguration(ServerConfiguration issuer) { 085 try { 086 if (!whitelist.isEmpty() && !whitelist.contains(issuer.getIssuer())) { 087 throw new AuthenticationServiceException("Whitelist was nonempty, issuer was not in whitelist: " + issuer); 088 } 089 090 if (blacklist.contains(issuer.getIssuer())) { 091 throw new AuthenticationServiceException("Issuer was in blacklist: " + issuer); 092 } 093 094 return clients.get(issuer); 095 } catch (UncheckedExecutionException | ExecutionException e) { 096 logger.warn("Unable to get client configuration", e); 097 return null; 098 } 099 } 100 101 /** 102 * @return the template 103 */ 104 public RegisteredClient getTemplate() { 105 return template; 106 } 107 108 /** 109 * @param template the template to set 110 */ 111 public void setTemplate(RegisteredClient template) { 112 // make sure the template doesn't have unwanted fields set on it 113 if (template != null) { 114 template.setClientId(null); 115 template.setClientSecret(null); 116 template.setRegistrationClientUri(null); 117 template.setRegistrationAccessToken(null); 118 } 119 this.template = template; 120 } 121 122 /** 123 * @return the registeredClientService 124 */ 125 public RegisteredClientService getRegisteredClientService() { 126 return registeredClientService; 127 } 128 129 /** 130 * @param registeredClientService the registeredClientService to set 131 */ 132 public void setRegisteredClientService(RegisteredClientService registeredClientService) { 133 this.registeredClientService = registeredClientService; 134 } 135 136 137 /** 138 * @return the whitelist 139 */ 140 public Set<String> getWhitelist() { 141 return whitelist; 142 } 143 144 /** 145 * @param whitelist the whitelist to set 146 */ 147 public void setWhitelist(Set<String> whitelist) { 148 this.whitelist = whitelist; 149 } 150 151 /** 152 * @return the blacklist 153 */ 154 public Set<String> getBlacklist() { 155 return blacklist; 156 } 157 158 /** 159 * @param blacklist the blacklist to set 160 */ 161 public void setBlacklist(Set<String> blacklist) { 162 this.blacklist = blacklist; 163 } 164 165 166 /** 167 * Loader class that fetches the client information. 168 * 169 * If a client has been registered (ie, it's known to the RegisteredClientService), then this 170 * will fetch the client's configuration from the server. 171 * 172 * @author jricher 173 * 174 */ 175 public class DynamicClientRegistrationLoader extends CacheLoader<ServerConfiguration, RegisteredClient> { 176 private HttpComponentsClientHttpRequestFactory httpFactory; 177 private Gson gson = new Gson(); // note that this doesn't serialize nulls by default 178 179 public DynamicClientRegistrationLoader() { 180 this(HttpClientBuilder.create().useSystemProperties().build()); 181 } 182 183 public DynamicClientRegistrationLoader(HttpClient httpClient) { 184 this.httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient); 185 } 186 187 @Override 188 public RegisteredClient load(ServerConfiguration serverConfig) throws Exception { 189 RestTemplate restTemplate = new RestTemplate(httpFactory); 190 191 192 RegisteredClient knownClient = registeredClientService.getByIssuer(serverConfig.getIssuer()); 193 if (knownClient == null) { 194 195 // dynamically register this client 196 JsonObject jsonRequest = ClientDetailsEntityJsonProcessor.serialize(template); 197 String serializedClient = gson.toJson(jsonRequest); 198 199 HttpHeaders headers = new HttpHeaders(); 200 headers.setContentType(MediaType.APPLICATION_JSON); 201 headers.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON)); 202 203 HttpEntity<String> entity = new HttpEntity<>(serializedClient, headers); 204 205 try { 206 String registered = restTemplate.postForObject(serverConfig.getRegistrationEndpointUri(), entity, String.class); 207 208 RegisteredClient client = ClientDetailsEntityJsonProcessor.parseRegistered(registered); 209 210 // save this client for later 211 registeredClientService.save(serverConfig.getIssuer(), client); 212 213 return client; 214 } catch (RestClientException rce) { 215 throw new InvalidClientException("Error registering client with server"); 216 } 217 } else { 218 219 if (knownClient.getClientId() == null) { 220 221 // load this client's information from the server 222 HttpHeaders headers = new HttpHeaders(); 223 headers.set("Authorization", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, knownClient.getRegistrationAccessToken())); 224 headers.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON)); 225 226 HttpEntity<String> entity = new HttpEntity<>(headers); 227 228 try { 229 String registered = restTemplate.exchange(knownClient.getRegistrationClientUri(), HttpMethod.GET, entity, String.class).getBody(); 230 // TODO: handle HTTP errors 231 232 RegisteredClient client = ClientDetailsEntityJsonProcessor.parseRegistered(registered); 233 234 return client; 235 } catch (RestClientException rce) { 236 throw new InvalidClientException("Error loading previously registered client information from server"); 237 } 238 } else { 239 // it's got a client ID from the store, don't bother trying to load it 240 return knownClient; 241 } 242 } 243 } 244 245 } 246 247}