WebfingerIssuerService.java
/*******************************************************************************
* Copyright 2017 The MIT Internet Trust Consortium
*
* Portions copyright 2011-2013 The MITRE Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/**
*
*/
package org.mitre.openid.connect.client.service.impl;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import javax.servlet.http.HttpServletRequest;
import org.apache.http.client.HttpClient;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.impl.client.HttpClientBuilder;
import org.mitre.discovery.util.WebfingerURLNormalizer;
import org.mitre.openid.connect.client.model.IssuerServiceResponse;
import org.mitre.openid.connect.client.service.IssuerService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponents;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
/**
* Use Webfinger to discover the appropriate issuer for a user-given input string.
* @author jricher
*
*/
public class WebfingerIssuerService implements IssuerService {
/**
* Logger for this class
*/
private static final Logger logger = LoggerFactory.getLogger(WebfingerIssuerService.class);
// map of user input -> issuer, loaded dynamically from webfinger discover
private LoadingCache<String, LoadingResult> issuers;
// private data shuttle class to get back two bits of info from the cache loader
private class LoadingResult {
public String loginHint;
public String issuer;
public LoadingResult(String loginHint, String issuer) {
this.loginHint = loginHint;
this.issuer = issuer;
}
}
private Set<String> whitelist = new HashSet<>();
private Set<String> blacklist = new HashSet<>();
/**
* Name of the incoming parameter to check for discovery purposes.
*/
private String parameterName = "identifier";
/**
* URL of the page to forward to if no identifier is given.
*/
private String loginPageUrl;
/**
* Strict enfocement of "https"
*/
private boolean forceHttps = true;
public WebfingerIssuerService() {
this(HttpClientBuilder.create().useSystemProperties().build());
}
public WebfingerIssuerService(HttpClient httpClient) {
issuers = CacheBuilder.newBuilder().build(new WebfingerIssuerFetcher(httpClient));
}
/* (non-Javadoc)
* @see org.mitre.openid.connect.client.service.IssuerService#getIssuer(javax.servlet.http.HttpServletRequest)
*/
@Override
public IssuerServiceResponse getIssuer(HttpServletRequest request) {
String identifier = request.getParameter(parameterName);
if (!Strings.isNullOrEmpty(identifier)) {
try {
LoadingResult lr = issuers.get(identifier);
if (!whitelist.isEmpty() && !whitelist.contains(lr.issuer)) {
throw new AuthenticationServiceException("Whitelist was nonempty, issuer was not in whitelist: " + lr.issuer);
}
if (blacklist.contains(lr.issuer)) {
throw new AuthenticationServiceException("Issuer was in blacklist: " + lr.issuer);
}
return new IssuerServiceResponse(lr.issuer, lr.loginHint, request.getParameter("target_link_uri"));
} catch (UncheckedExecutionException | ExecutionException e) {
logger.warn("Issue fetching issuer for user input: " + identifier + ": " + e.getMessage());
return null;
}
} else {
logger.warn("No user input given, directing to login page: " + loginPageUrl);
return new IssuerServiceResponse(loginPageUrl);
}
}
/**
* @return the parameterName
*/
public String getParameterName() {
return parameterName;
}
/**
* @param parameterName the parameterName to set
*/
public void setParameterName(String parameterName) {
this.parameterName = parameterName;
}
/**
* @return the loginPageUrl
*/
public String getLoginPageUrl() {
return loginPageUrl;
}
/**
* @param loginPageUrl the loginPageUrl to set
*/
public void setLoginPageUrl(String loginPageUrl) {
this.loginPageUrl = loginPageUrl;
}
/**
* @return the whitelist
*/
public Set<String> getWhitelist() {
return whitelist;
}
/**
* @param whitelist the whitelist to set
*/
public void setWhitelist(Set<String> whitelist) {
this.whitelist = whitelist;
}
/**
* @return the blacklist
*/
public Set<String> getBlacklist() {
return blacklist;
}
/**
* @param blacklist the blacklist to set
*/
public void setBlacklist(Set<String> blacklist) {
this.blacklist = blacklist;
}
/**
* @return the forceHttps
*/
public boolean isForceHttps() {
return forceHttps;
}
/**
* @param forceHttps the forceHttps to set
*/
public void setForceHttps(boolean forceHttps) {
this.forceHttps = forceHttps;
}
/**
* @author jricher
*
*/
private class WebfingerIssuerFetcher extends CacheLoader<String, LoadingResult> {
private HttpComponentsClientHttpRequestFactory httpFactory;
private JsonParser parser = new JsonParser();
WebfingerIssuerFetcher(HttpClient httpClient) {
this.httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
}
@Override
public LoadingResult load(String identifier) throws Exception {
UriComponents key = WebfingerURLNormalizer.normalizeResource(identifier);
RestTemplate restTemplate = new RestTemplate(httpFactory);
// construct the URL to go to
String scheme = key.getScheme();
// preserving http scheme is strictly for demo system use only.
if (!Strings.isNullOrEmpty(scheme) &&scheme.equals("http")) {
if (forceHttps) {
throw new IllegalArgumentException("Scheme must not be 'http'");
} else {
logger.warn("Webfinger endpoint MUST use the https URI scheme, overriding by configuration");
scheme = "http://"; // add on colon and slashes.
}
} else {
// otherwise we don't know the scheme, assume HTTPS
scheme = "https://";
}
// do a webfinger lookup
URIBuilder builder = new URIBuilder(scheme
+ key.getHost()
+ (key.getPort() >= 0 ? ":" + key.getPort() : "")
+ Strings.nullToEmpty(key.getPath())
+ "/.well-known/webfinger"
+ (Strings.isNullOrEmpty(key.getQuery()) ? "" : "?" + key.getQuery())
);
builder.addParameter("resource", identifier);
builder.addParameter("rel", "http://openid.net/specs/connect/1.0/issuer");
try {
// do the fetch
logger.info("Loading: " + builder.toString());
String webfingerResponse = restTemplate.getForObject(builder.build(), String.class);
JsonElement json = parser.parse(webfingerResponse);
if (json != null && json.isJsonObject()) {
// find the issuer
JsonArray links = json.getAsJsonObject().get("links").getAsJsonArray();
for (JsonElement link : links) {
if (link.isJsonObject()) {
JsonObject linkObj = link.getAsJsonObject();
if (linkObj.has("href")
&& linkObj.has("rel")
&& linkObj.get("rel").getAsString().equals("http://openid.net/specs/connect/1.0/issuer")) {
// we found the issuer, return it
String href = linkObj.get("href").getAsString();
if (identifier.equals(href)
|| identifier.startsWith("http")) {
// try to avoid sending a URL as the login hint
return new LoadingResult(null, href);
} else {
// otherwise pass back whatever the user typed as a login hint
return new LoadingResult(identifier, href);
}
}
}
}
}
} catch (JsonParseException | RestClientException e) {
logger.warn("Failure in fetching webfinger input", e.getMessage());
}
// we couldn't find it!
if (key.getScheme().equals("http") || key.getScheme().equals("https")) {
// if it looks like HTTP then punt: return the input, hope for the best
logger.warn("Returning normalized input string as issuer, hoping for the best: " + identifier);
return new LoadingResult(null, identifier);
} else {
// if it's not HTTP, give up
logger.warn("Couldn't find issuer: " + identifier);
throw new IllegalArgumentException();
}
}
}
}