001/*******************************************************************************
002 * Copyright 2017 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.jwt.signer.service.impl;
019
020import java.security.NoSuchAlgorithmException;
021import java.security.spec.InvalidKeySpecException;
022import java.util.Collection;
023import java.util.HashMap;
024import java.util.HashSet;
025import java.util.Map;
026import java.util.Set;
027import java.util.UUID;
028
029import org.mitre.jose.keystore.JWKSetKeyStore;
030import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
031import org.slf4j.Logger;
032import org.slf4j.LoggerFactory;
033
034import com.google.common.base.Strings;
035import com.nimbusds.jose.JOSEException;
036import com.nimbusds.jose.JWSAlgorithm;
037import com.nimbusds.jose.JWSSigner;
038import com.nimbusds.jose.JWSVerifier;
039import com.nimbusds.jose.crypto.ECDSASigner;
040import com.nimbusds.jose.crypto.ECDSAVerifier;
041import com.nimbusds.jose.crypto.MACSigner;
042import com.nimbusds.jose.crypto.MACVerifier;
043import com.nimbusds.jose.crypto.RSASSASigner;
044import com.nimbusds.jose.crypto.RSASSAVerifier;
045import com.nimbusds.jose.jwk.ECKey;
046import com.nimbusds.jose.jwk.JWK;
047import com.nimbusds.jose.jwk.OctetSequenceKey;
048import com.nimbusds.jose.jwk.RSAKey;
049import com.nimbusds.jwt.SignedJWT;
050
051public class DefaultJWTSigningAndValidationService implements JWTSigningAndValidationService {
052
053        // map of identifier to signer
054        private Map<String, JWSSigner> signers = new HashMap<>();
055
056        // map of identifier to verifier
057        private Map<String, JWSVerifier> verifiers = new HashMap<>();
058
059        /**
060         * Logger for this class
061         */
062        private static final Logger logger = LoggerFactory.getLogger(DefaultJWTSigningAndValidationService.class);
063
064        private String defaultSignerKeyId;
065
066        private JWSAlgorithm defaultAlgorithm;
067
068        // map of identifier to key
069        private Map<String, JWK> keys = new HashMap<>();
070
071        /**
072         * Build this service based on the keys given. All public keys will be used
073         * to make verifiers, all private keys will be used to make signers.
074         *
075         * @param keys
076         *            A map of key identifier to key
077         *
078         * @throws InvalidKeySpecException
079         *             If the keys in the JWKs are not valid
080         * @throws NoSuchAlgorithmException
081         *             If there is no appropriate algorithm to tie the keys to.
082         */
083        public DefaultJWTSigningAndValidationService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException {
084                this.keys = keys;
085                buildSignersAndVerifiers();
086        }
087
088        /**
089         * Build this service based on the given keystore. All keys must have a key
090         * id ({@code kid}) field in order to be used.
091         *
092         * @param keyStore
093         *            the keystore to load all keys from
094         *
095         * @throws InvalidKeySpecException
096         *             If the keys in the JWKs are not valid
097         * @throws NoSuchAlgorithmException
098         *             If there is no appropriate algorithm to tie the keys to.
099         */
100        public DefaultJWTSigningAndValidationService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException {
101                // convert all keys in the keystore to a map based on key id
102                if (keyStore!= null && keyStore.getJwkSet() != null) {
103                        for (JWK key : keyStore.getKeys()) {
104                                if (!Strings.isNullOrEmpty(key.getKeyID())) {
105                                        // use the key ID that's built into the key itself
106                                        this.keys.put(key.getKeyID(), key);
107                                } else {
108                                        // create a random key id
109                                        String fakeKid = UUID.randomUUID().toString();
110                                        this.keys.put(fakeKid, key);
111                                }
112                        }
113                }
114                buildSignersAndVerifiers();
115        }
116
117
118        /**
119         * @return the defaultSignerKeyId
120         */
121        @Override
122        public String getDefaultSignerKeyId() {
123                return defaultSignerKeyId;
124        }
125
126        /**
127         * @param defaultSignerKeyId the defaultSignerKeyId to set
128         */
129        public void setDefaultSignerKeyId(String defaultSignerId) {
130                this.defaultSignerKeyId = defaultSignerId;
131        }
132
133        /**
134         * @return
135         */
136        @Override
137        public JWSAlgorithm getDefaultSigningAlgorithm() {
138                return defaultAlgorithm;
139        }
140
141        public void setDefaultSigningAlgorithmName(String algName) {
142                defaultAlgorithm = JWSAlgorithm.parse(algName);
143        }
144
145        public String getDefaultSigningAlgorithmName() {
146                if (defaultAlgorithm != null) {
147                        return defaultAlgorithm.getName();
148                } else {
149                        return null;
150                }
151        }
152
153        /**
154         * Build all of the signers and verifiers for this based on the key map.
155         * @throws InvalidKeySpecException If the keys in the JWKs are not valid
156         * @throws NoSuchAlgorithmException If there is no appropriate algorithm to tie the keys to.
157         */
158        private void buildSignersAndVerifiers() throws NoSuchAlgorithmException, InvalidKeySpecException {
159                for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {
160
161                        String id = jwkEntry.getKey();
162                        JWK jwk = jwkEntry.getValue();
163
164                        try {
165                                if (jwk instanceof RSAKey) {
166                                        // build RSA signers & verifiers
167
168                                        if (jwk.isPrivate()) { // only add the signer if there's a private key
169                                                RSASSASigner signer = new RSASSASigner((RSAKey) jwk);
170                                                signers.put(id, signer);
171                                        }
172
173                                        RSASSAVerifier verifier = new RSASSAVerifier((RSAKey) jwk);
174                                        verifiers.put(id, verifier);
175
176                                } else if (jwk instanceof ECKey) {
177                                        // build EC signers & verifiers
178
179                                        if (jwk.isPrivate()) {
180                                                ECDSASigner signer = new ECDSASigner((ECKey) jwk);
181                                                signers.put(id, signer);
182                                        }
183
184                                        ECDSAVerifier verifier = new ECDSAVerifier((ECKey) jwk);
185                                        verifiers.put(id, verifier);
186
187                                } else if (jwk instanceof OctetSequenceKey) {
188                                        // build HMAC signers & verifiers
189
190                                        if (jwk.isPrivate()) { // technically redundant check because all HMAC keys are private
191                                                MACSigner signer = new MACSigner((OctetSequenceKey) jwk);
192                                                signers.put(id, signer);
193                                        }
194
195                                        MACVerifier verifier = new MACVerifier((OctetSequenceKey) jwk);
196                                        verifiers.put(id, verifier);
197
198                                } else {
199                                        logger.warn("Unknown key type: " + jwk);
200                                }
201                        } catch (JOSEException e) {
202                                logger.warn("Exception loading signer/verifier", e);
203                        }
204                }
205
206                if (defaultSignerKeyId == null && keys.size() == 1) {
207                        // if there's only one key, it's the default
208                        setDefaultSignerKeyId(keys.keySet().iterator().next());
209                }
210        }
211
212        /**
213         * Sign a jwt in place using the configured default signer.
214         */
215        @Override
216        public void signJwt(SignedJWT jwt) {
217                if (getDefaultSignerKeyId() == null) {
218                        throw new IllegalStateException("Tried to call default signing with no default signer ID set");
219                }
220
221                JWSSigner signer = signers.get(getDefaultSignerKeyId());
222
223                try {
224                        jwt.sign(signer);
225                } catch (JOSEException e) {
226
227                        logger.error("Failed to sign JWT, error was: ", e);
228                }
229
230        }
231
232        @Override
233        public void signJwt(SignedJWT jwt, JWSAlgorithm alg) {
234
235                JWSSigner signer = null;
236
237                for (JWSSigner s : signers.values()) {
238                        if (s.supportedJWSAlgorithms().contains(alg)) {
239                                signer = s;
240                                break;
241                        }
242                }
243
244                if (signer == null) {
245                        //If we can't find an algorithm that matches, we can't sign
246                        logger.error("No matching algirthm found for alg=" + alg);
247
248                }
249
250                try {
251                        jwt.sign(signer);
252                } catch (JOSEException e) {
253
254                        logger.error("Failed to sign JWT, error was: ", e);
255                }
256
257        }
258
259        @Override
260        public boolean validateSignature(SignedJWT jwt) {
261
262                for (JWSVerifier verifier : verifiers.values()) {
263                        try {
264                                if (jwt.verify(verifier)) {
265                                        return true;
266                                }
267                        } catch (JOSEException e) {
268
269                                logger.error("Failed to validate signature with " + verifier + " error message: " + e.getMessage());
270                        }
271                }
272                return false;
273        }
274
275        @Override
276        public Map<String, JWK> getAllPublicKeys() {
277                Map<String, JWK> pubKeys = new HashMap<>();
278
279                // pull all keys out of the verifiers if we know how
280                for (String keyId : keys.keySet()) {
281                        JWK key = keys.get(keyId);
282                        JWK pub = key.toPublicJWK();
283                        if (pub != null) {
284                                pubKeys.put(keyId, pub);
285                        }
286                }
287
288                return pubKeys;
289        }
290
291        /* (non-Javadoc)
292         * @see org.mitre.jwt.signer.service.JwtSigningAndValidationService#getAllSigningAlgsSupported()
293         */
294        @Override
295        public Collection<JWSAlgorithm> getAllSigningAlgsSupported() {
296
297                Set<JWSAlgorithm> algs = new HashSet<>();
298
299                for (JWSSigner signer : signers.values()) {
300                        algs.addAll(signer.supportedJWSAlgorithms());
301                }
302
303                for (JWSVerifier verifier : verifiers.values()) {
304                        algs.addAll(verifier.supportedJWSAlgorithms());
305                }
306
307                return algs;
308
309        }
310
311}