001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018/**
019 *
020 */
021package org.mitre.openid.connect.client.service.impl;
022
023import java.util.HashSet;
024import java.util.Set;
025import java.util.concurrent.ExecutionException;
026
027import org.apache.http.client.HttpClient;
028import org.apache.http.impl.client.HttpClientBuilder;
029import org.mitre.oauth2.model.RegisteredClient;
030import org.mitre.openid.connect.ClientDetailsEntityJsonProcessor;
031import org.mitre.openid.connect.client.service.ClientConfigurationService;
032import org.mitre.openid.connect.client.service.RegisteredClientService;
033import org.mitre.openid.connect.config.ServerConfiguration;
034import org.slf4j.Logger;
035import org.slf4j.LoggerFactory;
036import org.springframework.http.HttpEntity;
037import org.springframework.http.HttpHeaders;
038import org.springframework.http.HttpMethod;
039import org.springframework.http.MediaType;
040import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
041import org.springframework.security.authentication.AuthenticationServiceException;
042import org.springframework.security.oauth2.common.OAuth2AccessToken;
043import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
044import org.springframework.web.client.RestClientException;
045import org.springframework.web.client.RestTemplate;
046
047import com.google.common.cache.CacheBuilder;
048import com.google.common.cache.CacheLoader;
049import com.google.common.cache.LoadingCache;
050import com.google.common.collect.Lists;
051import com.google.common.util.concurrent.UncheckedExecutionException;
052import com.google.gson.Gson;
053import com.google.gson.JsonObject;
054
055/**
056 * @author jricher
057 *
058 */
059public class DynamicRegistrationClientConfigurationService implements ClientConfigurationService {
060
061        /**
062         * Logger for this class
063         */
064        private static final Logger logger = LoggerFactory.getLogger(DynamicRegistrationClientConfigurationService.class);
065
066        private LoadingCache<ServerConfiguration, RegisteredClient> clients;
067
068        private RegisteredClientService registeredClientService = new InMemoryRegisteredClientService();
069
070        private RegisteredClient template;
071
072        private Set<String> whitelist = new HashSet<>();
073        private Set<String> blacklist = new HashSet<>();
074
075        public DynamicRegistrationClientConfigurationService() {
076                this(HttpClientBuilder.create().useSystemProperties().build());
077        }
078
079        public DynamicRegistrationClientConfigurationService(HttpClient httpClient) {
080                clients = CacheBuilder.newBuilder().build(new DynamicClientRegistrationLoader(httpClient));
081        }
082
083        @Override
084        public RegisteredClient getClientConfiguration(ServerConfiguration issuer) {
085                try {
086                        if (!whitelist.isEmpty() && !whitelist.contains(issuer.getIssuer())) {
087                                throw new AuthenticationServiceException("Whitelist was nonempty, issuer was not in whitelist: " + issuer);
088                        }
089
090                        if (blacklist.contains(issuer.getIssuer())) {
091                                throw new AuthenticationServiceException("Issuer was in blacklist: " + issuer);
092                        }
093
094                        return clients.get(issuer);
095                } catch (UncheckedExecutionException | ExecutionException e) {
096                        logger.warn("Unable to get client configuration", e);
097                        return null;
098                }
099        }
100
101        /**
102         * @return the template
103         */
104        public RegisteredClient getTemplate() {
105                return template;
106        }
107
108        /**
109         * @param template the template to set
110         */
111        public void setTemplate(RegisteredClient template) {
112                // make sure the template doesn't have unwanted fields set on it
113                if (template != null) {
114                        template.setClientId(null);
115                        template.setClientSecret(null);
116                        template.setRegistrationClientUri(null);
117                        template.setRegistrationAccessToken(null);
118                }
119                this.template = template;
120        }
121
122        /**
123         * @return the registeredClientService
124         */
125        public RegisteredClientService getRegisteredClientService() {
126                return registeredClientService;
127        }
128
129        /**
130         * @param registeredClientService the registeredClientService to set
131         */
132        public void setRegisteredClientService(RegisteredClientService registeredClientService) {
133                this.registeredClientService = registeredClientService;
134        }
135
136
137        /**
138         * @return the whitelist
139         */
140        public Set<String> getWhitelist() {
141                return whitelist;
142        }
143
144        /**
145         * @param whitelist the whitelist to set
146         */
147        public void setWhitelist(Set<String> whitelist) {
148                this.whitelist = whitelist;
149        }
150
151        /**
152         * @return the blacklist
153         */
154        public Set<String> getBlacklist() {
155                return blacklist;
156        }
157
158        /**
159         * @param blacklist the blacklist to set
160         */
161        public void setBlacklist(Set<String> blacklist) {
162                this.blacklist = blacklist;
163        }
164
165
166        /**
167         * Loader class that fetches the client information.
168         *
169         * If a client has been registered (ie, it's known to the RegisteredClientService), then this
170         * will fetch the client's configuration from the server.
171         *
172         * @author jricher
173         *
174         */
175        public class DynamicClientRegistrationLoader extends CacheLoader<ServerConfiguration, RegisteredClient> {
176                private HttpComponentsClientHttpRequestFactory httpFactory;
177                private Gson gson = new Gson(); // note that this doesn't serialize nulls by default
178
179                public DynamicClientRegistrationLoader() {
180                        this(HttpClientBuilder.create().useSystemProperties().build());
181                }
182
183                public DynamicClientRegistrationLoader(HttpClient httpClient) {
184                        this.httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
185                }
186
187                @Override
188                public RegisteredClient load(ServerConfiguration serverConfig) throws Exception {
189                        RestTemplate restTemplate = new RestTemplate(httpFactory);
190
191
192                        RegisteredClient knownClient = registeredClientService.getByIssuer(serverConfig.getIssuer());
193                        if (knownClient == null) {
194
195                                // dynamically register this client
196                                JsonObject jsonRequest = ClientDetailsEntityJsonProcessor.serialize(template);
197                                String serializedClient = gson.toJson(jsonRequest);
198
199                                HttpHeaders headers = new HttpHeaders();
200                                headers.setContentType(MediaType.APPLICATION_JSON);
201                                headers.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON));
202
203                                HttpEntity<String> entity = new HttpEntity<>(serializedClient, headers);
204
205                                try {
206                                        String registered = restTemplate.postForObject(serverConfig.getRegistrationEndpointUri(), entity, String.class);
207
208                                        RegisteredClient client = ClientDetailsEntityJsonProcessor.parseRegistered(registered);
209
210                                        // save this client for later
211                                        registeredClientService.save(serverConfig.getIssuer(), client);
212
213                                        return client;
214                                } catch (RestClientException rce) {
215                                        throw new InvalidClientException("Error registering client with server");
216                                }
217                        } else {
218
219                                if (knownClient.getClientId() == null) {
220
221                                        // load this client's information from the server
222                                        HttpHeaders headers = new HttpHeaders();
223                                        headers.set("Authorization", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, knownClient.getRegistrationAccessToken()));
224                                        headers.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON));
225
226                                        HttpEntity<String> entity = new HttpEntity<>(headers);
227
228                                        try {
229                                                String registered = restTemplate.exchange(knownClient.getRegistrationClientUri(), HttpMethod.GET, entity, String.class).getBody();
230                                                // TODO: handle HTTP errors
231
232                                                RegisteredClient client = ClientDetailsEntityJsonProcessor.parseRegistered(registered);
233
234                                                return client;
235                                        } catch (RestClientException rce) {
236                                                throw new InvalidClientException("Error loading previously registered client information from server");
237                                        }
238                                } else {
239                                        // it's got a client ID from the store, don't bother trying to load it
240                                        return knownClient;
241                                }
242                        }
243                }
244
245        }
246
247}