001/******************************************************************************* 002 * Copyright 2017 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018package org.mitre.jwt.encryption.service.impl; 019 020import java.security.NoSuchAlgorithmException; 021import java.security.spec.InvalidKeySpecException; 022import java.util.Collection; 023import java.util.HashMap; 024import java.util.HashSet; 025import java.util.Map; 026import java.util.Set; 027 028import javax.annotation.PostConstruct; 029 030import org.mitre.jose.keystore.JWKSetKeyStore; 031import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; 032import org.slf4j.Logger; 033import org.slf4j.LoggerFactory; 034 035import com.google.common.base.Strings; 036import com.nimbusds.jose.EncryptionMethod; 037import com.nimbusds.jose.JOSEException; 038import com.nimbusds.jose.JWEAlgorithm; 039import com.nimbusds.jose.JWEDecrypter; 040import com.nimbusds.jose.JWEEncrypter; 041import com.nimbusds.jose.JWEObject; 042import com.nimbusds.jose.crypto.DirectDecrypter; 043import com.nimbusds.jose.crypto.DirectEncrypter; 044import com.nimbusds.jose.crypto.ECDHDecrypter; 045import com.nimbusds.jose.crypto.ECDHEncrypter; 046import com.nimbusds.jose.crypto.RSADecrypter; 047import com.nimbusds.jose.crypto.RSAEncrypter; 048import com.nimbusds.jose.crypto.bc.BouncyCastleProviderSingleton; 049import com.nimbusds.jose.jwk.ECKey; 050import com.nimbusds.jose.jwk.JWK; 051import com.nimbusds.jose.jwk.OctetSequenceKey; 052import com.nimbusds.jose.jwk.RSAKey; 053 054/** 055 * @author wkim 056 * 057 */ 058public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAndDecryptionService { 059 060 /** 061 * Logger for this class 062 */ 063 private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class); 064 065 // map of identifier to encrypter 066 private Map<String, JWEEncrypter> encrypters = new HashMap<>(); 067 068 // map of identifier to decrypter 069 private Map<String, JWEDecrypter> decrypters = new HashMap<>(); 070 071 private String defaultEncryptionKeyId; 072 073 private String defaultDecryptionKeyId; 074 075 private JWEAlgorithm defaultAlgorithm; 076 077 // map of identifier to key 078 private Map<String, JWK> keys = new HashMap<>(); 079 080 /** 081 * Build this service based on the keys given. All public keys will be used to make encrypters, 082 * all private keys will be used to make decrypters. 083 * 084 * @param keys 085 * @throws NoSuchAlgorithmException 086 * @throws InvalidKeySpecException 087 * @throws JOSEException 088 */ 089 public DefaultJWTEncryptionAndDecryptionService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException { 090 this.keys = keys; 091 buildEncryptersAndDecrypters(); 092 } 093 094 /** 095 * Build this service based on the given keystore. All keys must have a key 096 * id ({@code kid}) field in order to be used. 097 * 098 * @param keyStore 099 * @throws NoSuchAlgorithmException 100 * @throws InvalidKeySpecException 101 * @throws JOSEException 102 */ 103 public DefaultJWTEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException { 104 105 // convert all keys in the keystore to a map based on key id 106 for (JWK key : keyStore.getKeys()) { 107 if (!Strings.isNullOrEmpty(key.getKeyID())) { 108 this.keys.put(key.getKeyID(), key); 109 } else { 110 throw new IllegalArgumentException("Tried to load a key from a keystore without a 'kid' field: " + key); 111 } 112 } 113 114 buildEncryptersAndDecrypters(); 115 116 } 117 118 119 @PostConstruct 120 public void afterPropertiesSet() { 121 122 if (keys == null) { 123 throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured."); 124 } 125 try { 126 buildEncryptersAndDecrypters(); 127 } catch (NoSuchAlgorithmException e) { 128 throw new IllegalArgumentException("Encryption and decryption service could not find given algorithm."); 129 } catch (InvalidKeySpecException e) { 130 throw new IllegalArgumentException("Encryption and decryption service saw an invalid key specification."); 131 } catch (JOSEException e) { 132 throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object."); 133 } 134 } 135 136 public String getDefaultEncryptionKeyId() { 137 if (defaultEncryptionKeyId != null) { 138 return defaultEncryptionKeyId; 139 } else if (keys.size() == 1) { 140 // if there's only one key in the map, it's the default 141 return keys.keySet().iterator().next(); 142 } else { 143 return null; 144 } 145 } 146 147 public void setDefaultEncryptionKeyId(String defaultEncryptionKeyId) { 148 this.defaultEncryptionKeyId = defaultEncryptionKeyId; 149 } 150 151 public String getDefaultDecryptionKeyId() { 152 if (defaultDecryptionKeyId != null) { 153 return defaultDecryptionKeyId; 154 } else if (keys.size() == 1) { 155 // if there's only one key in the map, it's the default 156 return keys.keySet().iterator().next(); 157 } else { 158 return null; 159 } 160 } 161 162 public void setDefaultDecryptionKeyId(String defaultDecryptionKeyId) { 163 this.defaultDecryptionKeyId = defaultDecryptionKeyId; 164 } 165 166 public JWEAlgorithm getDefaultAlgorithm() { 167 return defaultAlgorithm; 168 } 169 170 public void setDefaultAlgorithm(JWEAlgorithm defaultAlgorithm) { 171 this.defaultAlgorithm = defaultAlgorithm; 172 } 173 174 /* (non-Javadoc) 175 * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#encryptJwt(com.nimbusds.jwt.EncryptedJWT) 176 */ 177 @Override 178 public void encryptJwt(JWEObject jwt) { 179 if (getDefaultEncryptionKeyId() == null) { 180 throw new IllegalStateException("Tried to call default encryption with no default encrypter ID set"); 181 } 182 183 JWEEncrypter encrypter = encrypters.get(getDefaultEncryptionKeyId()); 184 185 try { 186 jwt.encrypt(encrypter); 187 } catch (JOSEException e) { 188 189 logger.error("Failed to encrypt JWT, error was: ", e); 190 } 191 192 } 193 194 /* (non-Javadoc) 195 * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#decryptJwt(com.nimbusds.jwt.EncryptedJWT) 196 */ 197 @Override 198 public void decryptJwt(JWEObject jwt) { 199 if (getDefaultDecryptionKeyId() == null) { 200 throw new IllegalStateException("Tried to call default decryption with no default decrypter ID set"); 201 } 202 203 JWEDecrypter decrypter = decrypters.get(getDefaultDecryptionKeyId()); 204 205 try { 206 jwt.decrypt(decrypter); 207 } catch (JOSEException e) { 208 209 logger.error("Failed to decrypt JWT, error was: ", e); 210 } 211 212 } 213 214 /** 215 * Builds all the encrypters and decrypters for this service based on the key map. 216 * @throws 217 * @throws InvalidKeySpecException 218 * @throws NoSuchAlgorithmException 219 * @throws JOSEException 220 */ 221 private void buildEncryptersAndDecrypters() throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException { 222 223 for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) { 224 225 String id = jwkEntry.getKey(); 226 JWK jwk = jwkEntry.getValue(); 227 228 if (jwk instanceof RSAKey) { 229 // build RSA encrypters and decrypters 230 231 RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk); // there should always at least be the public key 232 encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 233 encrypters.put(id, encrypter); 234 235 if (jwk.isPrivate()) { // we can decrypt! 236 RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk); 237 decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 238 decrypters.put(id, decrypter); 239 } else { 240 logger.warn("No private key for key #" + jwk.getKeyID()); 241 } 242 } else if (jwk instanceof ECKey) { 243 244 // build EC Encrypters and decrypters 245 246 ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk); 247 encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 248 encrypters.put(id, encrypter); 249 250 if (jwk.isPrivate()) { // we can decrypt too 251 ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk); 252 decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 253 decrypters.put(id, decrypter); 254 } else { 255 logger.warn("No private key for key # " + jwk.getKeyID()); 256 } 257 258 } else if (jwk instanceof OctetSequenceKey) { 259 // build symmetric encrypters and decrypters 260 261 DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk); 262 encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 263 DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk); 264 decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance()); 265 266 encrypters.put(id, encrypter); 267 decrypters.put(id, decrypter); 268 269 } else { 270 logger.warn("Unknown key type: " + jwk); 271 } 272 273 } 274 } 275 276 @Override 277 public Map<String, JWK> getAllPublicKeys() { 278 Map<String, JWK> pubKeys = new HashMap<>(); 279 280 // pull out all public keys 281 for (String keyId : keys.keySet()) { 282 JWK key = keys.get(keyId); 283 JWK pub = key.toPublicJWK(); 284 if (pub != null) { 285 pubKeys.put(keyId, pub); 286 } 287 } 288 289 return pubKeys; 290 } 291 292 @Override 293 public Collection<JWEAlgorithm> getAllEncryptionAlgsSupported() { 294 Set<JWEAlgorithm> algs = new HashSet<>(); 295 296 for (JWEEncrypter encrypter : encrypters.values()) { 297 algs.addAll(encrypter.supportedJWEAlgorithms()); 298 } 299 300 for (JWEDecrypter decrypter : decrypters.values()) { 301 algs.addAll(decrypter.supportedJWEAlgorithms()); 302 } 303 304 return algs; 305 } 306 307 /* (non-Javadoc) 308 * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#getAllEncryptionEncsSupported() 309 */ 310 @Override 311 public Collection<EncryptionMethod> getAllEncryptionEncsSupported() { 312 Set<EncryptionMethod> encs = new HashSet<>(); 313 314 for (JWEEncrypter encrypter : encrypters.values()) { 315 encs.addAll(encrypter.supportedEncryptionMethods()); 316 } 317 318 for (JWEDecrypter decrypter : decrypters.values()) { 319 encs.addAll(decrypter.supportedEncryptionMethods()); 320 } 321 322 return encs; 323 } 324 325 326}