AuthorizationRequestFilter.java
/*******************************************************************************
* Copyright 2017 The MIT Internet Trust Consortium
*
* Portions copyright 2011-2013 The MITRE Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/**
*
*/
package org.mitre.openid.connect.filter;
import static org.mitre.openid.connect.request.ConnectRequestParameters.ERROR;
import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT;
import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_REQUIRED;
import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT;
import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_LOGIN;
import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_NONE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_SEPARATOR;
import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.apache.http.client.utils.URIBuilder;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.openid.connect.service.LoginHintExtracter;
import org.mitre.openid.connect.service.impl.RemoveLoginHintsWithHTTP;
import org.mitre.openid.connect.web.AuthenticationTimeStamper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.OAuth2RequestFactory;
import org.springframework.security.oauth2.provider.endpoint.RedirectResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.GenericFilterBean;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
/**
 * Servlet filter that inspects authorization requests before they reach the rest of the
 * filter chain and applies the {@code login_hint}, {@code prompt} and {@code max_age}
 * request parameters.
 *
 * <ul>
 * <li>{@code login_hint}: the extracted hint is stored in (or cleared from) the HTTP session.</li>
 * <li>{@code prompt=none}: if no authentication is present, the request is answered with a
 * {@code login_required} error redirect (or a 403 when no redirect can be built).</li>
 * <li>{@code prompt=login}: the current authentication is cleared once per prompt cycle.</li>
 * <li>{@code max_age} (or the client's default): the current authentication is cleared when
 * the recorded authentication time is older than the limit. This is only evaluated when no
 * {@code prompt} parameter is present.</li>
 * </ul>
 *
 * Requests that do not match the configured {@link RequestMatcher} are passed through untouched.
 *
 * @author jricher
 */
@Component("authRequestFilter")
public class AuthorizationRequestFilter extends GenericFilterBean {

    /**
     * Logger for this class
     */
    private static final Logger logger = LoggerFactory.getLogger(AuthorizationRequestFilter.class);

    /**
     * Session attribute this filter reads to tell whether the user has already been prompted;
     * it is removed here once seen. It is not set in this class (presumably set by the login
     * flow after a successful prompt -- confirm against the writer).
     */
    public final static String PROMPTED = "PROMPT_FILTER_PROMPTED";

    /** Session attribute set to {@link Boolean#TRUE} by this filter when {@code prompt=login} is requested. */
    public final static String PROMPT_REQUESTED = "PROMPT_FILTER_REQUESTED";

    @Autowired
    private OAuth2RequestFactory authRequestFactory;

    @Autowired
    private ClientDetailsEntityService clientService;

    @Autowired
    private RedirectResolver redirectResolver;

    // optional collaborator: falls back to this default when no bean is available
    @Autowired(required = false)
    private LoginHintExtracter loginHintExtracter = new RemoveLoginHintsWithHTTP();

    private RequestMatcher requestMatcher = new AntPathRequestMatcher("/authorize");

    /**
     * Applies the login hint, prompt and max-age handling to matching requests and then,
     * unless a response has already been sent, hands the request to the rest of the chain.
     *
     * @param req the incoming request; cast to {@link HttpServletRequest}
     * @param res the outgoing response; cast to {@link HttpServletResponse}
     * @param chain the remaining filter chain; invoked at most once per call
     * @throws IOException if sending a redirect/error fails or the chain throws it
     * @throws ServletException if the chain throws it
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        // skip everything that's not an authorize URL
        if (!requestMatcher.matches(request)) {
            chain.doFilter(req, res);
            return;
        }

        // only touch the session for requests we actually handle, so that
        // unrelated requests do not get a session created on their behalf
        HttpSession session = request.getSession();

        boolean proceed;
        try {
            proceed = processAuthorizationRequest(request, response, session);
        } catch (InvalidClientException e) {
            // we couldn't find the client, move on and let the rest of the system catch the error
            logger.debug("Client lookup failed while pre-processing the authorization request", e);
            proceed = true;
        }

        // The chain is invoked outside of the try block on purpose: an exception coming
        // back from downstream must never cause the chain to be invoked a second time.
        if (proceed) {
            chain.doFilter(req, res);
        }
    }

    /**
     * Runs all of the pre-chain processing for an authorization request.
     *
     * @param request the current HTTP request
     * @param response the current HTTP response; written to only when the request is rejected
     * @param session the current HTTP session
     * @return {@code true} if the filter chain should continue, {@code false} if a redirect
     *         or error response has already been sent
     * @throws IOException if sending the redirect or error response fails
     */
    private boolean processAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, HttpSession session)
            throws IOException {

        // we have to create our own auth request in order to get at all the parameters appropriately
        AuthorizationRequest authRequest = authRequestFactory.createAuthorizationRequest(createRequestMap(request.getParameterMap()));

        ClientDetailsEntity client = null;
        if (!Strings.isNullOrEmpty(authRequest.getClientId())) {
            client = clientService.loadClientByClientId(authRequest.getClientId());
        }

        // save the login hint to the session
        // but first check to see if the login hint makes any sense
        String loginHint = loginHintExtracter.extractHint((String) authRequest.getExtensions().get(LOGIN_HINT));
        if (!Strings.isNullOrEmpty(loginHint)) {
            session.setAttribute(LOGIN_HINT, loginHint);
        } else {
            session.removeAttribute(LOGIN_HINT);
        }

        String prompt = (String) authRequest.getExtensions().get(PROMPT);
        if (prompt != null) {
            // we have a "prompt" parameter
            List<String> prompts = Splitter.on(PROMPT_SEPARATOR).splitToList(prompt);
            if (prompts.contains(PROMPT_NONE)) {
                return handlePromptNone(authRequest, client, response);
            }
            if (prompts.contains(PROMPT_LOGIN)) {
                handlePromptLogin(session);
            }
            // any other prompt value is not our business
            return true;
        }

        // no prompt parameter: apply the max age limit, if there is one
        enforceMaxAge(resolveMaxAge(authRequest, client), session);
        return true;
    }

    /**
     * Handles {@code prompt=none}: lets the request through when an authentication is present,
     * otherwise answers with a {@code login_required} error.
     *
     * @param authRequest the parsed authorization request
     * @param client the client that made the request, or {@code null} if none was loaded
     * @param response the response used to send the redirect or the 403 error
     * @return {@code true} if the chain should continue, {@code false} if a response was sent
     * @throws IOException if sending the redirect or error response fails
     */
    private boolean handlePromptNone(AuthorizationRequest authRequest, ClientDetailsEntity client, HttpServletResponse response)
            throws IOException {

        // see if the user's logged in
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth != null) {
            // user's been logged in already (by session management)
            // we're OK, continue without prompting
            return true;
        }

        logger.info("Client requested no prompt");

        // user hasn't been logged in, we need to "return an error"
        if (client != null && authRequest.getRedirectUri() != null) {
            // if we've got a redirect URI then we'll send it
            String url = redirectResolver.resolveRedirect(authRequest.getRedirectUri(), client);
            try {
                URIBuilder uriBuilder = new URIBuilder(url);
                uriBuilder.addParameter(ERROR, LOGIN_REQUIRED);
                if (!Strings.isNullOrEmpty(authRequest.getState())) {
                    uriBuilder.addParameter(STATE, authRequest.getState()); // copy the state parameter if one was given
                }
                response.sendRedirect(uriBuilder.toString());
                return false;
            } catch (URISyntaxException e) {
                // fall through to the plain error response below
                logger.error("Can't build redirect URI for prompt=none, sending error instead", e);
            }
        }

        response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied");
        return false;
    }

    /**
     * Handles {@code prompt=login}: on the first pass, flags the session and clears any
     * existing authentication; once the session carries the {@link #PROMPTED} marker, the
     * marker is removed and the existing authentication is left alone.
     *
     * @param session the current HTTP session
     */
    private void handlePromptLogin(HttpSession session) {
        // first see if the user's already been prompted in this session
        if (session.getAttribute(PROMPTED) != null) {
            // user has been PROMPTED, we're fine
            // but first, undo the prompt tag
            session.removeAttribute(PROMPTED);
            return;
        }

        // user hasn't been PROMPTED yet, we need to check
        session.setAttribute(PROMPT_REQUESTED, Boolean.TRUE);

        if (SecurityContextHolder.getContext().getAuthentication() != null) {
            // user's been logged in already (by session management): log them out and continue
            SecurityContextHolder.getContext().setAuthentication(null);
        }
        // otherwise the user hasn't been logged in yet, we can keep going since we'll get there
    }

    /**
     * Works out the maximum authentication age that applies to the request: the
     * {@code max_age} request parameter when it is a valid integer, otherwise the client's
     * stored default.
     *
     * @param authRequest the parsed authorization request
     * @param client the client that made the request, or {@code null} if none was loaded
     * @return the limit in seconds, or {@code null} if no limit applies
     */
    private Integer resolveMaxAge(AuthorizationRequest authRequest, ClientDetailsEntity client) {
        // default to the client's stored value, check the string parameter
        Integer max = (client != null ? client.getDefaultMaxAge() : null);

        String requested = (String) authRequest.getExtensions().get(MAX_AGE);
        if (requested != null) {
            try {
                max = Integer.parseInt(requested);
            } catch (NumberFormatException e) {
                // the parameter comes straight from the request; a malformed value must not
                // blow up the filter, so ignore it and keep the client's default (if any)
                logger.warn("Ignoring max_age request parameter because it is not a valid integer");
            }
        }
        return max;
    }

    /**
     * Clears the current authentication when the authentication timestamp stored in the
     * session is more than {@code max} seconds old. Does nothing when there is no limit or
     * no timestamp.
     *
     * @param max the maximum authentication age in seconds, or {@code null} for no limit
     * @param session the current HTTP session, read for the authentication timestamp
     */
    private void enforceMaxAge(Integer max, HttpSession session) {
        if (max == null) {
            return;
        }

        Date authTime = (Date) session.getAttribute(AuthenticationTimeStamper.AUTH_TIMESTAMP);
        if (authTime == null) {
            return;
        }

        long seconds = (new Date().getTime() - authTime.getTime()) / 1000;
        if (seconds > max) {
            // session is too old, log the user out and continue
            SecurityContextHolder.getContext().setAuthentication(null);
        }
    }

    /**
     * Flattens a servlet parameter map into a single-valued map.
     *
     * @param parameterMap the multi-valued request parameters
     * @return a new map holding the first value of every parameter that has at least one
     *         value; parameters with a {@code null} or empty value array are omitted
     */
    private Map<String, String> createRequestMap(Map<String, String[]> parameterMap) {
        Map<String, String> requestMap = new HashMap<>();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            String[] values = entry.getValue();
            if (values != null && values.length > 0) {
                requestMap.put(entry.getKey(), values[0]); // add the first value only (which is what Spring seems to do)
            }
        }
        return requestMap;
    }

    /**
     * @return the requestMatcher that selects which requests this filter processes
     */
    public RequestMatcher getRequestMatcher() {
        return requestMatcher;
    }

    /**
     * @param requestMatcher the requestMatcher to set; requests it does not match are passed through untouched
     */
    public void setRequestMatcher(RequestMatcher requestMatcher) {
        this.requestMatcher = requestMatcher;
    }
}