001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.jwt.encryption.service.impl;
019
020import java.security.NoSuchAlgorithmException;
021import java.security.spec.InvalidKeySpecException;
022import java.util.Collection;
023import java.util.HashMap;
024import java.util.HashSet;
025import java.util.Map;
026import java.util.Set;
027
028import javax.annotation.PostConstruct;
029
030import org.mitre.jose.keystore.JWKSetKeyStore;
031import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
032import org.slf4j.Logger;
033import org.slf4j.LoggerFactory;
034
035import com.google.common.base.Strings;
036import com.nimbusds.jose.EncryptionMethod;
037import com.nimbusds.jose.JOSEException;
038import com.nimbusds.jose.JWEAlgorithm;
039import com.nimbusds.jose.JWEDecrypter;
040import com.nimbusds.jose.JWEEncrypter;
041import com.nimbusds.jose.JWEObject;
042import com.nimbusds.jose.crypto.DirectDecrypter;
043import com.nimbusds.jose.crypto.DirectEncrypter;
044import com.nimbusds.jose.crypto.ECDHDecrypter;
045import com.nimbusds.jose.crypto.ECDHEncrypter;
046import com.nimbusds.jose.crypto.RSADecrypter;
047import com.nimbusds.jose.crypto.RSAEncrypter;
048import com.nimbusds.jose.crypto.bc.BouncyCastleProviderSingleton;
049import com.nimbusds.jose.jwk.ECKey;
050import com.nimbusds.jose.jwk.JWK;
051import com.nimbusds.jose.jwk.OctetSequenceKey;
052import com.nimbusds.jose.jwk.RSAKey;
053
054/**
055 * @author wkim
056 *
057 */
058public class DefaultJWTEncryptionAndDecryptionService implements JWTEncryptionAndDecryptionService {
059
060        /**
061         * Logger for this class
062         */
063        private static final Logger logger = LoggerFactory.getLogger(DefaultJWTEncryptionAndDecryptionService.class);
064
065        // map of identifier to encrypter
066        private Map<String, JWEEncrypter> encrypters = new HashMap<>();
067
068        // map of identifier to decrypter
069        private Map<String, JWEDecrypter> decrypters = new HashMap<>();
070
071        private String defaultEncryptionKeyId;
072
073        private String defaultDecryptionKeyId;
074
075        private JWEAlgorithm defaultAlgorithm;
076
077        // map of identifier to key
078        private Map<String, JWK> keys = new HashMap<>();
079
080        /**
081         * Build this service based on the keys given. All public keys will be used to make encrypters,
082         * all private keys will be used to make decrypters.
083         *
084         * @param keys
085         * @throws NoSuchAlgorithmException
086         * @throws InvalidKeySpecException
087         * @throws JOSEException
088         */
089        public DefaultJWTEncryptionAndDecryptionService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
090                this.keys = keys;
091                buildEncryptersAndDecrypters();
092        }
093
094        /**
095         * Build this service based on the given keystore. All keys must have a key
096         * id ({@code kid}) field in order to be used.
097         *
098         * @param keyStore
099         * @throws NoSuchAlgorithmException
100         * @throws InvalidKeySpecException
101         * @throws JOSEException
102         */
103        public DefaultJWTEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
104
105                // convert all keys in the keystore to a map based on key id
106                for (JWK key : keyStore.getKeys()) {
107                        if (!Strings.isNullOrEmpty(key.getKeyID())) {
108                                this.keys.put(key.getKeyID(), key);
109                        } else {
110                                throw new IllegalArgumentException("Tried to load a key from a keystore without a 'kid' field: " + key);
111                        }
112                }
113
114                buildEncryptersAndDecrypters();
115
116        }
117
118
119        @PostConstruct
120        public void afterPropertiesSet() {
121
122                if (keys == null) {
123                        throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
124                }
125                try {
126                        buildEncryptersAndDecrypters();
127                } catch (NoSuchAlgorithmException e) {
128                        throw new IllegalArgumentException("Encryption and decryption service could not find given algorithm.");
129                } catch (InvalidKeySpecException e) {
130                        throw new IllegalArgumentException("Encryption and decryption service saw an invalid key specification.");
131                } catch (JOSEException e) {
132                        throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object.");
133                }
134        }
135
136        public String getDefaultEncryptionKeyId() {
137                if (defaultEncryptionKeyId != null) {
138                        return defaultEncryptionKeyId;
139                } else if (keys.size() == 1) {
140                        // if there's only one key in the map, it's the default
141                        return keys.keySet().iterator().next();
142                } else {
143                        return null;
144                }
145        }
146
147        public void setDefaultEncryptionKeyId(String defaultEncryptionKeyId) {
148                this.defaultEncryptionKeyId = defaultEncryptionKeyId;
149        }
150
151        public String getDefaultDecryptionKeyId() {
152                if (defaultDecryptionKeyId != null) {
153                        return defaultDecryptionKeyId;
154                } else if (keys.size() == 1) {
155                        // if there's only one key in the map, it's the default
156                        return keys.keySet().iterator().next();
157                } else {
158                        return null;
159                }
160        }
161
162        public void setDefaultDecryptionKeyId(String defaultDecryptionKeyId) {
163                this.defaultDecryptionKeyId = defaultDecryptionKeyId;
164        }
165
166        public JWEAlgorithm getDefaultAlgorithm() {
167                return defaultAlgorithm;
168        }
169
170        public void setDefaultAlgorithm(JWEAlgorithm defaultAlgorithm) {
171                this.defaultAlgorithm = defaultAlgorithm;
172        }
173
174        /* (non-Javadoc)
175         * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#encryptJwt(com.nimbusds.jwt.EncryptedJWT)
176         */
177        @Override
178        public void encryptJwt(JWEObject jwt) {
179                if (getDefaultEncryptionKeyId() == null) {
180                        throw new IllegalStateException("Tried to call default encryption with no default encrypter ID set");
181                }
182
183                JWEEncrypter encrypter = encrypters.get(getDefaultEncryptionKeyId());
184
185                try {
186                        jwt.encrypt(encrypter);
187                } catch (JOSEException e) {
188
189                        logger.error("Failed to encrypt JWT, error was: ", e);
190                }
191
192        }
193
194        /* (non-Javadoc)
195         * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#decryptJwt(com.nimbusds.jwt.EncryptedJWT)
196         */
197        @Override
198        public void decryptJwt(JWEObject jwt) {
199                if (getDefaultDecryptionKeyId() == null) {
200                        throw new IllegalStateException("Tried to call default decryption with no default decrypter ID set");
201                }
202
203                JWEDecrypter decrypter = decrypters.get(getDefaultDecryptionKeyId());
204
205                try {
206                        jwt.decrypt(decrypter);
207                } catch (JOSEException e) {
208
209                        logger.error("Failed to decrypt JWT, error was: ", e);
210                }
211
212        }
213
214        /**
215         * Builds all the encrypters and decrypters for this service based on the key map.
216         * @throws
217         * @throws InvalidKeySpecException
218         * @throws NoSuchAlgorithmException
219         * @throws JOSEException
220         */
221        private void buildEncryptersAndDecrypters() throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
222
223                for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
224
225                        String id = jwkEntry.getKey();
226                        JWK jwk = jwkEntry.getValue();
227
228                        if (jwk instanceof RSAKey) {
229                                // build RSA encrypters and decrypters
230
231                                RSAEncrypter encrypter = new RSAEncrypter((RSAKey) jwk); // there should always at least be the public key
232                                encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
233                                encrypters.put(id, encrypter);
234
235                                if (jwk.isPrivate()) { // we can decrypt!
236                                        RSADecrypter decrypter = new RSADecrypter((RSAKey) jwk);
237                                        decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
238                                        decrypters.put(id, decrypter);
239                                } else {
240                                        logger.warn("No private key for key #" + jwk.getKeyID());
241                                }
242                        } else if (jwk instanceof ECKey) {
243
244                                // build EC Encrypters and decrypters
245
246                                ECDHEncrypter encrypter = new ECDHEncrypter((ECKey) jwk);
247                                encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
248                                encrypters.put(id, encrypter);
249
250                                if (jwk.isPrivate()) { // we can decrypt too
251                                        ECDHDecrypter decrypter = new ECDHDecrypter((ECKey) jwk);
252                                        decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
253                                        decrypters.put(id, decrypter);
254                                } else {
255                                        logger.warn("No private key for key # " + jwk.getKeyID());
256                                }
257
258                        } else if (jwk instanceof OctetSequenceKey) {
259                                // build symmetric encrypters and decrypters
260
261                                DirectEncrypter encrypter = new DirectEncrypter((OctetSequenceKey) jwk);
262                                encrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
263                                DirectDecrypter decrypter = new DirectDecrypter((OctetSequenceKey) jwk);
264                                decrypter.getJCAContext().setProvider(BouncyCastleProviderSingleton.getInstance());
265
266                                encrypters.put(id, encrypter);
267                                decrypters.put(id, decrypter);
268
269                        } else {
270                                logger.warn("Unknown key type: " + jwk);
271                        }
272
273                }
274        }
275
276        @Override
277        public Map<String, JWK> getAllPublicKeys() {
278                Map<String, JWK> pubKeys = new HashMap<>();
279
280                // pull out all public keys
281                for (String keyId : keys.keySet()) {
282                        JWK key = keys.get(keyId);
283                        JWK pub = key.toPublicJWK();
284                        if (pub != null) {
285                                pubKeys.put(keyId, pub);
286                        }
287                }
288
289                return pubKeys;
290        }
291
292        @Override
293        public Collection<JWEAlgorithm> getAllEncryptionAlgsSupported() {
294                Set<JWEAlgorithm> algs = new HashSet<>();
295
296                for (JWEEncrypter encrypter : encrypters.values()) {
297                        algs.addAll(encrypter.supportedJWEAlgorithms());
298                }
299
300                for (JWEDecrypter decrypter : decrypters.values()) {
301                        algs.addAll(decrypter.supportedJWEAlgorithms());
302                }
303
304                return algs;
305        }
306
307        /* (non-Javadoc)
308         * @see org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService#getAllEncryptionEncsSupported()
309         */
310        @Override
311        public Collection<EncryptionMethod> getAllEncryptionEncsSupported() {
312                Set<EncryptionMethod> encs = new HashSet<>();
313
314                for (JWEEncrypter encrypter : encrypters.values()) {
315                        encs.addAll(encrypter.supportedEncryptionMethods());
316                }
317
318                for (JWEDecrypter decrypter : decrypters.values()) {
319                        encs.addAll(decrypter.supportedEncryptionMethods());
320                }
321
322                return encs;
323        }
324
325
326}