DefaultJWTEncryptionAndDecryptionService.java
/*******************************************************************************
* Copyright 2017 The MIT Internet Trust Consortium
*
* Portions copyright 2011-2013 The MITRE Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package org.mitre.jwt.encryption.service.impl;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.PostConstruct;
import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Strings;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.crypto.DirectDecrypter;
import com.nimbusds.jose.crypto.DirectEncrypter;
import com.nimbusds.jose.crypto.ECDHDecrypter;
import com.nimbusds.jose.crypto.ECDHEncrypter;
import com.nimbusds.jose.crypto.RSADecrypter;
import com.nimbusds.jose.crypto.RSAEncrypter;
import com.nimbusds.jose.crypto.bc.BouncyCastleProviderSingleton;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
/**
 * Default implementation of {@link JWTEncryptionAndDecryptionService} backed by a map of key
 * identifiers to {@link JWK}s.
 *
 * <p>An encrypter is built for every configured RSA, EC or octet-sequence (symmetric) key. For RSA
 * and EC keys a decrypter is only built when the key contains private material; for symmetric keys
 * a decrypter is always built. Keys of any other type are skipped with a warning. Every encrypter
 * and decrypter is configured to use the Bouncy Castle JCA provider.
 *
 * @author wkim
 *
 */
public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAndDecryptionService {

	/**
	 * Logger for this class
	 */
	private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class);

	// map of identifier to encrypter
	private final Map<String, JWEEncrypter> encrypters = new HashMap<>();

	// map of identifier to decrypter; only keys able to decrypt have an entry here
	private final Map<String, JWEDecrypter> decrypters = new HashMap<>();

	private String defaultEncryptionKeyId;

	private String defaultDecryptionKeyId;

	private JWEAlgorithm defaultAlgorithm;

	// map of identifier to key
	private Map<String, JWK> keys = new HashMap<>();

	/**
	 * Build this service based on the keys given. All public keys will be used to make encrypters,
	 * all private keys will be used to make decrypters.
	 *
	 * @param keys map of key identifier to key; the map is used as-is, not copied
	 * @throws IllegalArgumentException if {@code keys} is {@code null}
	 * @throws NoSuchAlgorithmException propagated from building the encrypters and decrypters
	 * @throws InvalidKeySpecException propagated from building the encrypters and decrypters
	 * @throws JOSEException if an encrypter or decrypter cannot be built from one of the keys
	 */
	public DefaultJWTEncryptionAndDecryptionService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
		if (keys == null) {
			// same check as afterPropertiesSet(), done here so we fail before the build NPEs
			throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
		}
		this.keys = keys;
		buildEncryptersAndDecrypters();
	}

	/**
	 * Build this service based on the given keystore. All keys must have a key
	 * id ({@code kid}) field in order to be used.
	 *
	 * @param keyStore keystore whose keys are registered under their {@code kid}
	 * @throws IllegalArgumentException if any key in the keystore has a null or empty {@code kid}
	 * @throws NoSuchAlgorithmException propagated from building the encrypters and decrypters
	 * @throws InvalidKeySpecException propagated from building the encrypters and decrypters
	 * @throws JOSEException if an encrypter or decrypter cannot be built from one of the keys
	 */
	public DefaultJWTEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
		// convert all keys in the keystore to a map based on key id
		for (JWK key : keyStore.getKeys()) {
			String keyId = key.getKeyID();
			if (Strings.isNullOrEmpty(keyId)) {
				// Never put the key itself in the message: its JSON form can contain private key material.
				throw new IllegalArgumentException("Tried to load a key from a keystore without a 'kid' field (key type: " + key.getKeyType() + ")");
			}
			if (this.keys.put(keyId, key) != null) {
				logger.warn("Duplicate key ID '{}' in keystore; only the last key with this ID will be used", keyId);
			}
		}
		buildEncryptersAndDecrypters();
	}

	/**
	 * (Re)builds the encrypters and decrypters from the configured keys after dependency injection.
	 *
	 * @throws IllegalArgumentException if no key map is configured or the keys cannot be turned into
	 *         encrypters/decrypters; the underlying exception is attached as the cause
	 */
	@PostConstruct
	public void afterPropertiesSet() {
		if (keys == null) {
			throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
		}
		try {
			buildEncryptersAndDecrypters();
		} catch (NoSuchAlgorithmException e) {
			throw new IllegalArgumentException("Encryption and decryption service could not find given algorithm.", e);
		} catch (InvalidKeySpecException e) {
			throw new IllegalArgumentException("Encryption and decryption service saw an invalid key specification.", e);
		} catch (JOSEException e) {
			throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object.", e);
		}
	}

	/**
	 * @return the configured default encryption key ID; otherwise the only key's ID when exactly one
	 *         key is configured; otherwise {@code null}
	 */
	public String getDefaultEncryptionKeyId() {
		return resolveDefaultKeyId(defaultEncryptionKeyId);
	}

	public void setDefaultEncryptionKeyId(String defaultEncryptionKeyId) {
		this.defaultEncryptionKeyId = defaultEncryptionKeyId;
	}

	/**
	 * @return the configured default decryption key ID; otherwise the only key's ID when exactly one
	 *         key is configured; otherwise {@code null}
	 */
	public String getDefaultDecryptionKeyId() {
		return resolveDefaultKeyId(defaultDecryptionKeyId);
	}

	public void setDefaultDecryptionKeyId(String defaultDecryptionKeyId) {
		this.defaultDecryptionKeyId = defaultDecryptionKeyId;
	}

	public JWEAlgorithm getDefaultAlgorithm() {
		return defaultAlgorithm;
	}

	public void setDefaultAlgorithm(JWEAlgorithm defaultAlgorithm) {
		this.defaultAlgorithm = defaultAlgorithm;
	}

	/**
	 * Encrypts the given JWE object in place with the encrypter for the default encryption key.
	 * A {@link JOSEException} raised during encryption is logged, not rethrown.
	 *
	 * @param jwt the unencrypted JWE object to encrypt
	 * @throws IllegalStateException if no default encryption key ID is available, or no encrypter
	 *         exists for it
	 */
	@Override
	public void encryptJwt(JWEObject jwt) {
		String keyId = getDefaultEncryptionKeyId();
		if (keyId == null) {
			throw new IllegalStateException("Tried to call default encryption with no default encrypter ID set");
		}
		JWEEncrypter encrypter = encrypters.get(keyId);
		if (encrypter == null) {
			// e.g. the key ID is unknown or refers to an unsupported key type
			throw new IllegalStateException("No encrypter available for default encryption key ID: " + keyId);
		}
		try {
			jwt.encrypt(encrypter);
		} catch (JOSEException e) {
			logger.error("Failed to encrypt JWT, error was: ", e);
		}
	}

	/**
	 * Decrypts the given JWE object in place with the decrypter for the default decryption key.
	 * A {@link JOSEException} raised during decryption is logged, not rethrown.
	 *
	 * @param jwt the encrypted JWE object to decrypt
	 * @throws IllegalStateException if no default decryption key ID is available, or no decrypter
	 *         exists for it (for instance when that key has no private part)
	 */
	@Override
	public void decryptJwt(JWEObject jwt) {
		String keyId = getDefaultDecryptionKeyId();
		if (keyId == null) {
			throw new IllegalStateException("Tried to call default decryption with no default decrypter ID set");
		}
		JWEDecrypter decrypter = decrypters.get(keyId);
		if (decrypter == null) {
			// public-only RSA/EC keys and unsupported key types never get a decrypter
			throw new IllegalStateException("No decrypter available for default decryption key ID: " + keyId);
		}
		try {
			jwt.decrypt(decrypter);
		} catch (JOSEException e) {
			logger.error("Failed to decrypt JWT, error was: ", e);
		}
	}

	/**
	 * Returns the explicitly configured key ID if set; otherwise, when exactly one key is
	 * configured, that key's ID (it is the implicit default); otherwise {@code null}.
	 */
	private String resolveDefaultKeyId(String configuredKeyId) {
		if (configuredKeyId != null) {
			return configuredKeyId;
		}
		if (keys.size() == 1) {
			return keys.keySet().iterator().next();
		}
		return null;
	}

	/**
	 * Builds all the encrypters and decrypters for this service based on the key map, replacing
	 * any previously built ones.
	 *
	 * @throws NoSuchAlgorithmException declared for the constructors' and afterPropertiesSet()'s
	 *         existing contract
	 * @throws InvalidKeySpecException declared for the constructors' and afterPropertiesSet()'s
	 *         existing contract
	 * @throws JOSEException if an encrypter or decrypter cannot be created for a key
	 */
	private void buildEncryptersAndDecrypters() throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
		// this runs from the constructor and again from afterPropertiesSet(); start from scratch so
		// the maps mirror the current key map instead of accumulating stale entries
		encrypters.clear();
		decrypters.clear();

		for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
			String id = jwkEntry.getKey();
			JWK jwk = jwkEntry.getValue();

			if (jwk instanceof RSAKey) {
				RSAKey rsaKey = (RSAKey) jwk;
				// build RSA encrypters and decrypters; there should always at least be the public key
				RSAEncrypter encrypter = new RSAEncrypter(rsaKey);
				encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
				encrypters.put(id, encrypter);

				if (rsaKey.isPrivate()) { // we can decrypt!
					RSADecrypter decrypter = new RSADecrypter(rsaKey);
					decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
					decrypters.put(id, decrypter);
				} else {
					logger.warn("No private key for key #{}", id);
				}

			} else if (jwk instanceof ECKey) {
				ECKey ecKey = (ECKey) jwk;
				// build EC encrypters and decrypters
				ECDHEncrypter encrypter = new ECDHEncrypter(ecKey);
				encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
				encrypters.put(id, encrypter);

				if (ecKey.isPrivate()) { // we can decrypt too
					ECDHDecrypter decrypter = new ECDHDecrypter(ecKey);
					decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
					decrypters.put(id, decrypter);
				} else {
					logger.warn("No private key for key #{}", id);
				}

			} else if (jwk instanceof OctetSequenceKey) {
				OctetSequenceKey octetKey = (OctetSequenceKey) jwk;
				// build symmetric encrypters and decrypters; the shared secret serves both directions
				DirectEncrypter encrypter = new DirectEncrypter(octetKey);
				encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
				DirectDecrypter decrypter = new DirectDecrypter(octetKey);
				decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
				encrypters.put(id, encrypter);
				decrypters.put(id, decrypter);

			} else {
				// Log only the key type: the JWK's string form can contain private key material.
				logger.warn("Unknown key type for key #{}: {}", id, jwk == null ? null : jwk.getKeyType());
			}
		}
	}

	/**
	 * Returns the public form of every configured key, keyed by key ID. Keys whose public form is
	 * {@code null} are left out.
	 */
	@Override
	public Map<String, JWK> getAllPublicKeys() {
		Map<String, JWK> pubKeys = new HashMap<>();
		// pull out all public keys
		for (Map.Entry<String, JWK> entry : keys.entrySet()) {
			JWK pub = entry.getValue().toPublicJWK();
			if (pub != null) {
				pubKeys.put(entry.getKey(), pub);
			}
		}
		return pubKeys;
	}

	/**
	 * Returns the union of the JWE algorithms supported by all built encrypters and decrypters.
	 */
	@Override
	public Collection<JWEAlgorithm> getAllEncryptionAlgsSupported() {
		Set<JWEAlgorithm> algs = new HashSet<>();
		for (JWEEncrypter encrypter : encrypters.values()) {
			algs.addAll(encrypter.supportedJWEAlgorithms());
		}
		for (JWEDecrypter decrypter : decrypters.values()) {
			algs.addAll(decrypter.supportedJWEAlgorithms());
		}
		return algs;
	}

	/**
	 * Returns the union of the content encryption methods supported by all built encrypters and
	 * decrypters.
	 */
	@Override
	public Collection<EncryptionMethod> getAllEncryptionEncsSupported() {
		Set<EncryptionMethod> encs = new HashSet<>();
		for (JWEEncrypter encrypter : encrypters.values()) {
			encs.addAll(encrypter.supportedEncryptionMethods());
		}
		for (JWEDecrypter decrypter : decrypters.values()) {
			encs.addAll(decrypter.supportedEncryptionMethods());
		}
		return encs;
	}
}