UserInfoJWTView.java

  1. /*******************************************************************************
  2.  * Copyright 2017 The MIT Internet Trust Consortium
  3.  *
  4.  * Licensed under the Apache License, Version 2.0 (the "License");
  5.  * you may not use this file except in compliance with the License.
  6.  * You may obtain a copy of the License at
  7.  *
  8.  *   http://www.apache.org/licenses/LICENSE-2.0
  9.  *
  10.  * Unless required by applicable law or agreed to in writing, software
  11.  * distributed under the License is distributed on an "AS IS" BASIS,
  12.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13.  * See the License for the specific language governing permissions and
  14.  * limitations under the License.
  15.  *******************************************************************************/
  16. /**
  17.  *
  18.  */
  19. package org.mitre.openid.connect.view;

  20. import java.io.IOException;
  21. import java.io.StringWriter;
  22. import java.io.Writer;
  23. import java.text.ParseException;
  24. import java.util.Date;
  25. import java.util.Map;
  26. import java.util.UUID;

  27. import javax.servlet.http.HttpServletRequest;
  28. import javax.servlet.http.HttpServletResponse;

  29. import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
  30. import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
  31. import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
  32. import org.mitre.jwt.signer.service.impl.SymmetricKeyJWTValidatorCacheService;
  33. import org.mitre.oauth2.model.ClientDetailsEntity;
  34. import org.mitre.openid.connect.config.ConfigurationPropertiesBean;
  35. import org.slf4j.Logger;
  36. import org.slf4j.LoggerFactory;
  37. import org.springframework.beans.factory.annotation.Autowired;
  38. import org.springframework.http.MediaType;
  39. import org.springframework.stereotype.Component;

  40. import com.google.common.base.Strings;
  41. import com.google.common.collect.Lists;
  42. import com.google.gson.JsonObject;
  43. import com.nimbusds.jose.Algorithm;
  44. import com.nimbusds.jose.JWEHeader;
  45. import com.nimbusds.jose.JWSAlgorithm;
  46. import com.nimbusds.jose.JWSHeader;
  47. import com.nimbusds.jwt.EncryptedJWT;
  48. import com.nimbusds.jwt.JWTClaimsSet;
  49. import com.nimbusds.jwt.SignedJWT;

  50. /**
  51.  * @author jricher
  52.  *
  53.  */
  54. @Component(UserInfoJWTView.VIEWNAME)
  55. public class UserInfoJWTView extends UserInfoView {

  56.     public static final String CLIENT = "client";

  57.     /**
  58.      * Logger for this class
  59.      */
  60.     private static final Logger logger = LoggerFactory.getLogger(UserInfoJWTView.class);

  61.     public static final String VIEWNAME = "userInfoJwtView";

  62.     public static final String JOSE_MEDIA_TYPE_VALUE = "application/jwt";
  63.     public static final MediaType JOSE_MEDIA_TYPE = new MediaType("application", "jwt");


  64.     @Autowired
  65.     private JWTSigningAndValidationService jwtService;

  66.     @Autowired
  67.     private ConfigurationPropertiesBean config;

  68.     @Autowired
  69.     private ClientKeyCacheService encrypters;

  70.     @Autowired
  71.     private SymmetricKeyJWTValidatorCacheService symmetricCacheService;

  72.     @Override
  73.     protected void writeOut(JsonObject json, Map<String, Object> model,
  74.             HttpServletRequest request, HttpServletResponse response) {

  75.         try {
  76.             ClientDetailsEntity client = (ClientDetailsEntity)model.get(CLIENT);

  77.             // use the parser to import the user claims into the object
  78.             StringWriter writer = new StringWriter();
  79.             gson.toJson(json, writer);

  80.             response.setContentType(JOSE_MEDIA_TYPE_VALUE);

  81.             JWTClaimsSet claims = new JWTClaimsSet.Builder(JWTClaimsSet.parse(writer.toString()))
  82.                     .audience(Lists.newArrayList(client.getClientId()))
  83.                     .issuer(config.getIssuer())
  84.                     .issueTime(new Date())
  85.                     .jwtID(UUID.randomUUID().toString()) // set a random NONCE in the middle of it
  86.                     .build();


  87.             if (client.getUserInfoEncryptedResponseAlg() != null && !client.getUserInfoEncryptedResponseAlg().equals(Algorithm.NONE)
  88.                     && client.getUserInfoEncryptedResponseEnc() != null && !client.getUserInfoEncryptedResponseEnc().equals(Algorithm.NONE)
  89.                     && (!Strings.isNullOrEmpty(client.getJwksUri()) || client.getJwks() != null)) {

  90.                 // encrypt it to the client's key

  91.                 JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client);

  92.                 if (encrypter != null) {

  93.                     EncryptedJWT encrypted = new EncryptedJWT(new JWEHeader(client.getUserInfoEncryptedResponseAlg(), client.getUserInfoEncryptedResponseEnc()), claims);

  94.                     encrypter.encryptJwt(encrypted);


  95.                     Writer out = response.getWriter();
  96.                     out.write(encrypted.serialize());

  97.                 } else {
  98.                     logger.error("Couldn't find encrypter for client: " + client.getClientId());
  99.                 }
  100.             } else {

  101.                 JWSAlgorithm signingAlg = jwtService.getDefaultSigningAlgorithm(); // default to the server's preference
  102.                 if (client.getUserInfoSignedResponseAlg() != null) {
  103.                     signingAlg = client.getUserInfoSignedResponseAlg(); // override with the client's preference if available
  104.                 }
  105.                 JWSHeader header = new JWSHeader(signingAlg, null, null, null, null, null, null, null, null, null,
  106.                         jwtService.getDefaultSignerKeyId(),
  107.                         null, null);
  108.                 SignedJWT signed = new SignedJWT(header, claims);

  109.                 if (signingAlg.equals(JWSAlgorithm.HS256)
  110.                         || signingAlg.equals(JWSAlgorithm.HS384)
  111.                         || signingAlg.equals(JWSAlgorithm.HS512)) {

  112.                     // sign it with the client's secret
  113.                     JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client);
  114.                     signer.signJwt(signed);

  115.                 } else {
  116.                     // sign it with the server's key
  117.                     jwtService.signJwt(signed);
  118.                 }

  119.                 Writer out = response.getWriter();
  120.                 out.write(signed.serialize());
  121.             }
  122.         } catch (IOException e) {
  123.             logger.error("IO Exception in UserInfoJwtView", e);
  124.         } catch (ParseException e) {
  125.             // TODO Auto-generated catch block
  126.             e.printStackTrace();
  127.         }

  128.     }
  129. }