DeviceEndpoint.java

/*******************************************************************************
 * Copyright 2017 The MIT Internet Trust Consortium
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

package org.mitre.oauth2.web;

import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import javax.servlet.http.HttpSession;

import org.mitre.oauth2.exception.DeviceCodeCreationException;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.DeviceCode;
import org.mitre.oauth2.model.SystemScope;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.DeviceCodeService;
import org.mitre.oauth2.service.SystemScopeService;
import org.mitre.oauth2.token.DeviceTokenGranter;
import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
import org.mitre.openid.connect.view.HttpCodeView;
import org.mitre.openid.connect.view.JsonEntityView;
import org.mitre.openid.connect.view.JsonErrorView;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.OAuth2RequestFactory;
import org.springframework.stereotype.Controller;
import org.springframework.ui.ModelMap;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;

import com.google.common.collect.Sets;

/**
 * Implements https://tools.ietf.org/html/draft-ietf-oauth-device-flow
 *
 * This controller exposes two groups of handlers:
 * <ul>
 * <li>{@code POST /devicecode}, called by the device (client) to obtain a
 * device code / user code pair;</li>
 * <li>{@code GET /device}, {@code POST /device/verify} and
 * {@code POST /device/approve}, used by a logged-in user to enter the user
 * code and then approve or deny the pending request.</li>
 * </ul>
 *
 * @see DeviceTokenGranter
 *
 * @author jricher
 *
 */
@Controller
public class DeviceEndpoint {

	public static final String URL = "devicecode";
	public static final String USER_URL = "device";

	public static final Logger logger = LoggerFactory.getLogger(DeviceEndpoint.class);

	// session attribute names shared between the verify and approve steps
	private static final String SESSION_AUTHORIZATION_REQUEST = "authorizationRequest";
	private static final String SESSION_DEVICE_CODE = "deviceCode";

	@Autowired
	private ClientDetailsEntityService clientService;

	@Autowired
	private SystemScopeService scopeService;

	@Autowired
	private ConfigurationPropertiesBean config;

	@Autowired
	private DeviceCodeService deviceCodeService;

	@Autowired
	private OAuth2RequestFactory oAuth2RequestFactory;

	/**
	 * Issue a new device code / user code pair for a client.
	 *
	 * @param clientId the requesting client's identifier
	 * @param scope space-delimited list of requested scopes, may be absent
	 * @param parameters all parameters of the incoming request; they are handed
	 *            to the device code service together with the new code
	 * @param model the model used to render the JSON (or error) view
	 * @return the name of the view to render
	 * @throws InvalidClientException if the client is restricted to a set of
	 *             grant types that does not include the device grant
	 */
	@RequestMapping(value = "/" + URL, method = RequestMethod.POST, consumes = MediaType.APPLICATION_FORM_URLENCODED_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
	public String requestDeviceCode(@RequestParam("client_id") String clientId, @RequestParam(name="scope", required=false) String scope,
			// @RequestParam is required here: Spring MVC resolves an un-annotated Map
			// argument to the model, not to the request parameters
			@RequestParam Map<String, String> parameters, ModelMap model) {

		ClientDetailsEntity client;
		try {
			client = clientService.loadClientByClientId(clientId);

			// check for a missing client before touching it, otherwise the
			// grant type check below would fail with a NullPointerException
			if (client == null) {
				logger.error("could not find client {}", clientId);
				model.put(HttpCodeView.CODE, HttpStatus.NOT_FOUND);
				return HttpCodeView.VIEWNAME;
			}

			// make sure this client can do the device flow; a client with no
			// configured grant types is not restricted

			Collection<String> authorizedGrantTypes = client.getAuthorizedGrantTypes();
			if (authorizedGrantTypes != null && !authorizedGrantTypes.isEmpty()
					&& !authorizedGrantTypes.contains(DeviceTokenGranter.GRANT_TYPE)) {
				throw new InvalidClientException("Unauthorized grant type: " + DeviceTokenGranter.GRANT_TYPE);
			}

		} catch (IllegalArgumentException e) {
			logger.error("IllegalArgumentException was thrown when attempting to load client", e);
			model.put(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
			return HttpCodeView.VIEWNAME;
		}

		// make sure the client is allowed to ask for those scopes
		Set<String> requestedScopes = OAuth2Utils.parseParameterList(scope);
		Set<String> allowedScopes = client.getScope();

		if (!scopeService.scopesMatch(allowedScopes, requestedScopes)) {
			// client asked for scopes it can't have
			logger.error("Client asked for {} but is allowed {}", requestedScopes, allowedScopes);
			model.put(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
			model.put(JsonErrorView.ERROR, "invalid_scope");
			return JsonErrorView.VIEWNAME;
		}

		// if we got here the request is legit

		try {
			DeviceCode dc = deviceCodeService.createNewDeviceCode(requestedScopes, client, parameters);

			Map<String, Object> response = new HashMap<>();
			response.put("device_code", dc.getDeviceCode());
			response.put("user_code", dc.getUserCode());
			response.put("verification_uri", config.getIssuer() + USER_URL);
			// expires_in is only reported when the client has a configured validity
			if (client.getDeviceCodeValiditySeconds() != null) {
				response.put("expires_in", client.getDeviceCodeValiditySeconds());
			}

			model.put(JsonEntityView.ENTITY, response);

			return JsonEntityView.VIEWNAME;
		} catch (DeviceCodeCreationException dcce) {

			model.put(HttpCodeView.CODE, HttpStatus.BAD_REQUEST);
			model.put(JsonErrorView.ERROR, dcce.getError());
			model.put(JsonErrorView.ERROR_MESSAGE, dcce.getMessage());

			return JsonErrorView.VIEWNAME;
		}

	}

	/**
	 * Show the page that asks the user to enter their user code.
	 * The user must be logged in.
	 *
	 * @param model the view model (not modified)
	 * @return the name of the user code entry view
	 */
	@PreAuthorize("hasRole('ROLE_USER')")
	@RequestMapping(value = "/" + USER_URL, method = RequestMethod.GET)
	public String requestUserCode(ModelMap model) {

		// print out a page that asks the user to enter their user code
		// user must be logged in

		return "requestUserCode";
	}

	/**
	 * Look up the pending request for a submitted user code and, if it is
	 * usable, show the approval page. The device code and the reconstructed
	 * authorization request are stored in the session for the approval step.
	 *
	 * @param userCode the user code entered by the user
	 * @param model the view model; receives either an {@code error} key or the
	 *            {@code client}, {@code dc} and {@code scopes} entries
	 * @param session the HTTP session carrying state to {@code approveDevice}
	 * @return {@code "approveDevice"} on success, {@code "requestUserCode"} if
	 *         the code is unknown, expired or already approved
	 */
	@PreAuthorize("hasRole('ROLE_USER')")
	@RequestMapping(value = "/" + USER_URL + "/verify", method = RequestMethod.POST)
	public String readUserCode(@RequestParam("user_code") String userCode, ModelMap model, HttpSession session) {

		// look up the request based on the user code
		DeviceCode dc = deviceCodeService.lookUpByUserCode(userCode);

		// we couldn't find the device code
		if (dc == null) {
			model.addAttribute("error", "noUserCode");
			return "requestUserCode";
		}

		// make sure the code hasn't expired yet
		if (isExpired(dc)) {
			model.addAttribute("error", "expiredUserCode");
			return "requestUserCode";
		}

		// make sure the device code hasn't already been approved
		if (dc.isApproved()) {
			model.addAttribute("error", "userCodeAlreadyApproved");
			return "requestUserCode";
		}

		ClientDetailsEntity client = clientService.loadClientByClientId(dc.getClientId());

		model.put("client", client);
		model.put("dc", dc);

		// pre-process the scopes
		model.put("scopes", sortScopesForDisplay(scopeService.fromStrings(dc.getScope())));

		AuthorizationRequest authorizationRequest = oAuth2RequestFactory.createAuthorizationRequest(dc.getRequestParameters());

		session.setAttribute(SESSION_AUTHORIZATION_REQUEST, authorizationRequest);
		session.setAttribute(SESSION_DEVICE_CODE, dc);

		return "approveDevice";
	}

	/**
	 * Process the user's decision on the approval page.
	 *
	 * @param userCode the user code carried by the submitted form; must match
	 *            the device code stored in the session by {@code readUserCode}
	 * @param approve {@code true} if the user approved the request
	 * @param model the view model; receives either an {@code error} key or the
	 *            {@code client} and {@code approved} entries (plus
	 *            {@code scopes} on approval)
	 * @param auth the authentication of the current user
	 * @param session the HTTP session holding the pending request
	 * @return {@code "deviceApproved"} once a decision was processed,
	 *         {@code "requestUserCode"} if the submission cannot be matched to a
	 *         pending request or the code has expired
	 */
	@PreAuthorize("hasRole('ROLE_USER')")
	@RequestMapping(value = "/" + USER_URL + "/approve", method = RequestMethod.POST)
	public String approveDevice(@RequestParam("user_code") String userCode, @RequestParam(value = "user_oauth_approval") Boolean approve, ModelMap model, Authentication auth, HttpSession session) {

		AuthorizationRequest authorizationRequest = (AuthorizationRequest) session.getAttribute(SESSION_AUTHORIZATION_REQUEST);
		DeviceCode dc = (DeviceCode) session.getAttribute(SESSION_DEVICE_CODE);

		// make sure the form that was submitted is the one that we were expecting;
		// a session without a pending request (verify step skipped or session
		// replaced) is handled the same way instead of failing on a null reference
		if (dc == null || authorizationRequest == null || !userCode.equals(dc.getUserCode())) {
			model.addAttribute("error", "userCodeMismatch");
			return "requestUserCode";
		}

		// make sure the code hasn't expired yet
		if (isExpired(dc)) {
			model.addAttribute("error", "expiredUserCode");
			return "requestUserCode";
		}

		ClientDetailsEntity client = clientService.loadClientByClientId(dc.getClientId());

		model.put("client", client);

		// user did not approve
		if (!approve) {
			model.addAttribute("approved", false);
			return "deviceApproved";
		}

		// create an OAuth request for storage
		OAuth2Request o2req = oAuth2RequestFactory.createOAuth2Request(authorizationRequest);
		OAuth2Authentication o2Auth = new OAuth2Authentication(o2req, auth);

		deviceCodeService.approveDeviceCode(dc, o2Auth);

		// pre-process the scopes
		model.put("scopes", sortScopesForDisplay(scopeService.fromStrings(dc.getScope())));
		model.put("approved", true);

		return "deviceApproved";
	}

	/**
	 * Tell whether a device code carries an expiration that lies in the past.
	 * A code without an expiration is never considered expired.
	 */
	private boolean isExpired(DeviceCode dc) {
		return dc.getExpiration() != null && dc.getExpiration().before(new Date());
	}

	/**
	 * Order scopes for display: scopes known to the scope service come first,
	 * in the iteration order of {@code scopeService.getAll()}, followed by any
	 * remaining scopes.
	 */
	private Set<SystemScope> sortScopesForDisplay(Set<SystemScope> scopes) {
		Set<SystemScope> sortedScopes = new LinkedHashSet<>(scopes.size());
		Set<SystemScope> systemScopes = scopeService.getAll();

		// sort scopes for display based on the inherent order of system scopes
		for (SystemScope s : systemScopes) {
			if (scopes.contains(s)) {
				sortedScopes.add(s);
			}
		}

		// add in any scopes that aren't system scopes to the end of the list
		sortedScopes.addAll(Sets.difference(scopes, systemScopes));

		return sortedScopes;
	}
}