1import math
2import warnings
3
4from cryptography.exceptions import InvalidSignature, InvalidTag
5from cryptography.hazmat.backends import default_backend
6from cryptography.hazmat.primitives import hashes, hmac, serialization
7from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
8from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
9from cryptography.hazmat.primitives.ciphers import Cipher, aead, algorithms, modes
10from cryptography.hazmat.primitives.keywrap import InvalidUnwrap, aes_key_unwrap, aes_key_wrap
11from cryptography.hazmat.primitives.padding import PKCS7
12from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
13from cryptography.utils import int_to_bytes
14from cryptography.x509 import load_pem_x509_certificate
15
16from ..constants import ALGORITHMS
17from ..exceptions import JWEError, JWKError
18from ..utils import (
19 base64_to_long,
20 base64url_decode,
21 base64url_encode,
22 ensure_binary,
23 is_pem_format,
24 is_ssh_key,
25 long_to_base64,
26)
27from . import get_random_bytes
28from .base import Key
29
# Module-level placeholder initialised to None; nothing in the visible code reads or reassigns it.
# NOTE(review): presumably retained for backward compatibility — confirm before removing.
_binding = None
31
32
class CryptographyECKey(Key):
    """ECDSA key (ES256 / ES384 / ES512) implemented on top of ``cryptography``.

    The constructor accepts, in order of precedence:

    * an object exposing ``public_bytes`` or ``private_bytes`` (used as-is),
    * an object exposing ``to_pem`` (converted to PEM text, then loaded),
    * a JWK ``dict`` with ``kty == "EC"``,
    * PEM text as ``str`` or ``bytes``.

    Signatures are exchanged in RAW form (``r || s``, each component
    zero-padded to the key size) and converted to/from DER internally.
    """

    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Prepare ``key`` for use with ``algorithm``.

        Args:
            key: key material in one of the forms listed on the class.
            algorithm: one of ``ALGORITHMS.EC``.
            cryptography_backend: zero-argument callable returning the backend
                handed to ``cryptography`` loaders.

        Raises:
            JWKError: if ``algorithm`` is not an EC algorithm or ``key`` cannot
                be parsed.
        """
        if algorithm not in ALGORITHMS.EC:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)

        self.hash_alg = {
            ALGORITHMS.ES256: self.SHA256,
            ALGORITHMS.ES384: self.SHA384,
            ALGORITHMS.ES512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        self.cryptography_backend = cryptography_backend

        if hasattr(key, "public_bytes") or hasattr(key, "private_bytes"):
            self.prepared_key = key
            return

        if hasattr(key, "to_pem"):
            # convert to PEM and let cryptography below load it as PEM
            key = key.to_pem().decode("utf-8")

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            # Attempt to load key. We don't know if it's
            # a Public Key or a Private Key, so we try
            # the Public Key first.
            try:
                try:
                    key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                # Keep the loader's error as the explicit cause.
                raise JWKError(e) from e

            self.prepared_key = key
            return

        raise JWKError("Unable to parse an ECKey from key: %s" % key)

    def _process_jwk(self, jwk_dict):
        """Build a ``cryptography`` EC key object from a JWK dictionary.

        Returns a private key when the JWK carries ``d``, otherwise a public key.

        Raises:
            JWKError: wrong ``kty``, missing ``x``/``y``/``crv``, or a ``crv``
                value other than P-256, P-384 or P-521.
        """
        if not jwk_dict.get("kty") == "EC":
            raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))

        if not all(k in jwk_dict for k in ["x", "y", "crv"]):
            raise JWKError("Mandatory parameters are missing")

        x = base64_to_long(jwk_dict.get("x"))
        y = base64_to_long(jwk_dict.get("y"))

        supported_curves = {
            "P-256": ec.SECP256R1,
            "P-384": ec.SECP384R1,
            "P-521": ec.SECP521R1,
        }
        try:
            curve = supported_curves[jwk_dict["crv"]]
        except (KeyError, TypeError):
            # Report an unsupported curve the same way as every other malformed
            # JWK instead of leaking a bare KeyError from the lookup table.
            raise JWKError("Unsupported curve: %s" % (jwk_dict["crv"],)) from None

        public = ec.EllipticCurvePublicNumbers(x, y, curve())

        if "d" in jwk_dict:
            d = base64_to_long(jwk_dict.get("d"))
            private = ec.EllipticCurvePrivateNumbers(d, public)

            return private.private_key(self.cryptography_backend())
        else:
            return public.public_key(self.cryptography_backend())

    def _sig_component_length(self):
        """Determine the correct serialization length for an encoded signature component.

        This is the number of bytes required to encode the maximum key value.
        """
        return int(math.ceil(self.prepared_key.key_size / 8.0))

    def _der_to_raw(self, der_signature):
        """Convert signature from DER encoding to RAW encoding."""
        r, s = decode_dss_signature(der_signature)
        component_length = self._sig_component_length()
        return int_to_bytes(r, component_length) + int_to_bytes(s, component_length)

    def _raw_to_der(self, raw_signature):
        """Convert signature from RAW encoding to DER encoding.

        Raises:
            ValueError: if the signature is not exactly two components long.
        """
        component_length = self._sig_component_length()
        if len(raw_signature) != int(2 * component_length):
            raise ValueError("Invalid signature")

        r_bytes = raw_signature[:component_length]
        s_bytes = raw_signature[component_length:]
        r = int.from_bytes(r_bytes, "big")
        s = int.from_bytes(s_bytes, "big")
        return encode_dss_signature(r, s)

    def sign(self, msg):
        """Sign ``msg`` and return the RAW-encoded signature.

        Raises:
            TypeError: if the digest is wider than the curve's key size.
        """
        if self.hash_alg.digest_size * 8 > self.prepared_key.curve.key_size:
            raise TypeError(
                "this curve (%s) is too short "
                "for your digest (%d)" % (self.prepared_key.curve.name, 8 * self.hash_alg.digest_size)
            )
        signature = self.prepared_key.sign(msg, ec.ECDSA(self.hash_alg()))
        return self._der_to_raw(signature)

    def verify(self, msg, sig):
        """Return True when RAW signature ``sig`` is valid for ``msg``.

        Any failure (malformed signature, invalid signature, unusable key)
        yields False rather than an exception.
        """
        try:
            signature = self._raw_to_der(sig)
            self.prepared_key.verify(signature, msg, ec.ECDSA(self.hash_alg()))
            return True
        except Exception:
            return False

    def is_public(self):
        """Return True when the prepared key exposes ``public_bytes``."""
        return hasattr(self.prepared_key, "public_bytes")

    def public_key(self):
        """Return this key if it is public, else a new key wrapping the public half."""
        if self.is_public():
            return self
        # Carry the configured backend factory over to the derived key.
        return self.__class__(self.prepared_key.public_key(), self._algorithm, self.cryptography_backend)

    def to_pem(self):
        """Serialize the key as PEM bytes.

        Public keys use SubjectPublicKeyInfo; private keys use the traditional
        OpenSSL format without encryption.
        """
        if self.is_public():
            pem = self.prepared_key.public_bytes(
                encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
            return pem
        pem = self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return pem

    def to_dict(self):
        """Export the key as a JWK dictionary (includes ``d`` for private keys)."""
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        crv = {
            "secp256r1": "P-256",
            "secp384r1": "P-384",
            "secp521r1": "P-521",
        }[self.prepared_key.curve.name]

        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
        # points must be encoded as octet-strings of this length.
        key_size = (self.prepared_key.curve.key_size + 7) // 8

        # Fetch the public numbers once instead of once per coordinate.
        public_numbers = public_key.public_numbers()

        data = {
            "alg": self._algorithm,
            "kty": "EC",
            "crv": crv,
            "x": long_to_base64(public_numbers.x, size=key_size).decode("ASCII"),
            "y": long_to_base64(public_numbers.y, size=key_size).decode("ASCII"),
        }

        if not self.is_public():
            private_value = self.prepared_key.private_numbers().private_value
            data["d"] = long_to_base64(private_value, size=key_size).decode("ASCII")

        return data
201
202
class CryptographyRSAKey(Key):
    """RSA key implemented on top of ``cryptography``.

    Supports PKCS#1 v1.5 signatures (RS256 / RS384 / RS512) and key wrapping
    with RSA1_5, RSA-OAEP and RSA-OAEP-256. The constructor accepts an object
    that looks like an RSA public/private key, a JWK ``dict``, or PEM text
    (``str``/``bytes``), including a PEM X.509 certificate.
    """

    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    RSA1_5 = padding.PKCS1v15()
    RSA_OAEP = padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)
    RSA_OAEP_256 = padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Prepare ``key`` for use with ``algorithm``.

        Args:
            key: key material in one of the forms listed on the class.
            algorithm: one of ``ALGORITHMS.RSA``.
            cryptography_backend: zero-argument callable returning the backend
                handed to ``cryptography`` loaders.

        Raises:
            JWKError: if ``algorithm`` is not an RSA algorithm or ``key`` cannot
                be parsed.
        """
        if algorithm not in ALGORITHMS.RSA:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)

        # None for the key-wrapping algorithms, which do not sign.
        self.hash_alg = {
            ALGORITHMS.RS256: self.SHA256,
            ALGORITHMS.RS384: self.SHA384,
            ALGORITHMS.RS512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        # None for the signing algorithms, which do not wrap keys.
        self.padding = {
            ALGORITHMS.RSA1_5: self.RSA1_5,
            ALGORITHMS.RSA_OAEP: self.RSA_OAEP,
            ALGORITHMS.RSA_OAEP_256: self.RSA_OAEP_256,
        }.get(algorithm)

        self.cryptography_backend = cryptography_backend

        # if it conforms to RSAPublicKey or RSAPrivateKey interface
        if (hasattr(key, "public_bytes") and hasattr(key, "public_numbers")) or hasattr(key, "private_bytes"):
            self.prepared_key = key
            return

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            try:
                if key.startswith(b"-----BEGIN CERTIFICATE-----"):
                    self._process_cert(key)
                    return

                # Public key first; fall back to private key on ValueError.
                try:
                    self.prepared_key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    self.prepared_key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                raise JWKError(e) from e
            return

        raise JWKError("Unable to parse an RSA_JWK from key: %s" % key)

    def _process_jwk(self, jwk_dict):
        """Build a ``cryptography`` RSA key object from a JWK dictionary.

        Returns a private key when the JWK carries ``d``, otherwise a public key.

        Raises:
            JWKError: wrong ``kty``, missing ``n``/``e``, or an incomplete set
                of precomputed private parameters.
        """
        if not jwk_dict.get("kty") == "RSA":
            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))

        # 'n' and 'e' are required members of an RSA JWK (RFC 7518, 6.3.1).
        # Reject their absence explicitly instead of feeding a non-base64
        # placeholder or None into the decoder.
        if not all(k in jwk_dict for k in ["n", "e"]):
            raise JWKError("Mandatory parameters are missing")

        e = base64_to_long(jwk_dict["e"])
        n = base64_to_long(jwk_dict["n"])
        public = rsa.RSAPublicNumbers(e, n)

        if "d" not in jwk_dict:
            return public.public_key(self.cryptography_backend())

        # This is a private key.
        d = base64_to_long(jwk_dict.get("d"))

        extra_params = ["p", "q", "dp", "dq", "qi"]

        if any(k in jwk_dict for k in extra_params):
            # Precomputed private key parameters are available.
            if not all(k in jwk_dict for k in extra_params):
                # These values must be present when 'p' is according to
                # Section 6.3.2 of RFC7518, so if they are not we raise
                # an error.
                raise JWKError("Precomputed private key parameters are incomplete.")

            p = base64_to_long(jwk_dict["p"])
            q = base64_to_long(jwk_dict["q"])
            dp = base64_to_long(jwk_dict["dp"])
            dq = base64_to_long(jwk_dict["dq"])
            qi = base64_to_long(jwk_dict["qi"])
        else:
            # The precomputed private key parameters are not available,
            # so we use cryptography's API to fill them in.
            p, q = rsa.rsa_recover_prime_factors(n, e, d)
            dp = rsa.rsa_crt_dmp1(d, p)
            dq = rsa.rsa_crt_dmq1(d, q)
            qi = rsa.rsa_crt_iqmp(p, q)

        private = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public)

        return private.private_key(self.cryptography_backend())

    def _process_cert(self, key):
        """Load a PEM X.509 certificate and keep its public key as the prepared key."""
        key = load_pem_x509_certificate(key, self.cryptography_backend())
        self.prepared_key = key.public_key()

    def sign(self, msg):
        """Return the PKCS#1 v1.5 signature of ``msg``.

        Raises:
            JWKError: wrapping any failure raised while signing.
        """
        try:
            signature = self.prepared_key.sign(msg, padding.PKCS1v15(), self.hash_alg())
        except Exception as e:
            raise JWKError(e) from e
        return signature

    def verify(self, msg, sig):
        """Return True when ``sig`` is a valid PKCS#1 v1.5 signature of ``msg``.

        Emits a warning when the key held is a private key; verification is
        then performed with its public half.
        """
        if not self.is_public():
            warnings.warn("Attempting to verify a message with a private key. This is not recommended.")

        try:
            self.public_key().prepared_key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
            return True
        except InvalidSignature:
            return False

    def is_public(self):
        """Return True when the prepared key exposes ``public_bytes``."""
        return hasattr(self.prepared_key, "public_bytes")

    def public_key(self):
        """Return this key if it is public, else a new key wrapping the public half."""
        if self.is_public():
            return self
        # Carry the configured backend factory over to the derived key.
        return self.__class__(self.prepared_key.public_key(), self._algorithm, self.cryptography_backend)

    def to_pem(self, pem_format="PKCS8"):
        """Serialize the key as PEM bytes.

        Args:
            pem_format: ``"PKCS8"`` (SubjectPublicKeyInfo for public keys,
                PKCS8 for private keys) or ``"PKCS1"`` (PKCS1 for public keys,
                traditional OpenSSL for private keys).

        Raises:
            ValueError: if ``pem_format`` is neither of the above.
        """
        if self.is_public():
            if pem_format == "PKCS8":
                fmt = serialization.PublicFormat.SubjectPublicKeyInfo
            elif pem_format == "PKCS1":
                fmt = serialization.PublicFormat.PKCS1
            else:
                raise ValueError("Invalid format specified: %r" % pem_format)
            pem = self.prepared_key.public_bytes(encoding=serialization.Encoding.PEM, format=fmt)
            return pem

        if pem_format == "PKCS8":
            fmt = serialization.PrivateFormat.PKCS8
        elif pem_format == "PKCS1":
            fmt = serialization.PrivateFormat.TraditionalOpenSSL
        else:
            raise ValueError("Invalid format specified: %r" % pem_format)

        return self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM, format=fmt, encryption_algorithm=serialization.NoEncryption()
        )

    def to_dict(self):
        """Export the key as a JWK dictionary.

        Private keys additionally include ``d``, ``p``, ``q``, ``dp``, ``dq``
        and ``qi``.
        """
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        # Fetch the numbers once instead of once per exported member.
        public_numbers = public_key.public_numbers()

        data = {
            "alg": self._algorithm,
            "kty": "RSA",
            "n": long_to_base64(public_numbers.n).decode("ASCII"),
            "e": long_to_base64(public_numbers.e).decode("ASCII"),
        }

        if not self.is_public():
            private_numbers = self.prepared_key.private_numbers()
            data.update(
                {
                    "d": long_to_base64(private_numbers.d).decode("ASCII"),
                    "p": long_to_base64(private_numbers.p).decode("ASCII"),
                    "q": long_to_base64(private_numbers.q).decode("ASCII"),
                    "dp": long_to_base64(private_numbers.dmp1).decode("ASCII"),
                    "dq": long_to_base64(private_numbers.dmq1).decode("ASCII"),
                    "qi": long_to_base64(private_numbers.iqmp).decode("ASCII"),
                }
            )

        return data

    def wrap_key(self, key_data):
        """Encrypt ``key_data`` with the padding selected by the algorithm.

        Raises:
            JWEError: wrapping any failure raised while encrypting.
        """
        try:
            wrapped_key = self.prepared_key.encrypt(key_data, self.padding)
        except Exception as e:
            raise JWEError(e) from e

        return wrapped_key

    def unwrap_key(self, wrapped_key):
        """Decrypt ``wrapped_key`` with the padding selected by the algorithm.

        Raises:
            JWEError: wrapping any failure raised while decrypting.
        """
        try:
            unwrapped_key = self.prepared_key.decrypt(wrapped_key, self.padding)
            return unwrapped_key
        except Exception as e:
            raise JWEError(e) from e
393
class CryptographyAESKey(Key):
    """Symmetric AES key supporting GCM, CBC and AES key wrap via ``cryptography``.

    The ``KEY_*`` tuples group algorithms by the key length (in bits) the
    constructor requires; ``MODES`` maps each algorithm to its cipher mode
    class, or None for the AES key-wrap algorithms.
    """

    KEY_128 = (ALGORITHMS.A128GCM, ALGORITHMS.A128GCMKW, ALGORITHMS.A128KW, ALGORITHMS.A128CBC)
    KEY_192 = (ALGORITHMS.A192GCM, ALGORITHMS.A192GCMKW, ALGORITHMS.A192KW, ALGORITHMS.A192CBC)
    KEY_256 = (
        ALGORITHMS.A256GCM,
        ALGORITHMS.A256GCMKW,
        ALGORITHMS.A256KW,
        ALGORITHMS.A128CBC_HS256,
        ALGORITHMS.A256CBC,
    )
    KEY_384 = (ALGORITHMS.A192CBC_HS384,)
    KEY_512 = (ALGORITHMS.A256CBC_HS512,)

    AES_KW_ALGS = (ALGORITHMS.A128KW, ALGORITHMS.A192KW, ALGORITHMS.A256KW)

    MODES = {
        ALGORITHMS.A128GCM: modes.GCM,
        ALGORITHMS.A192GCM: modes.GCM,
        ALGORITHMS.A256GCM: modes.GCM,
        ALGORITHMS.A128CBC_HS256: modes.CBC,
        ALGORITHMS.A192CBC_HS384: modes.CBC,
        ALGORITHMS.A256CBC_HS512: modes.CBC,
        ALGORITHMS.A128CBC: modes.CBC,
        ALGORITHMS.A192CBC: modes.CBC,
        ALGORITHMS.A256CBC: modes.CBC,
        ALGORITHMS.A128GCMKW: modes.GCM,
        ALGORITHMS.A192GCMKW: modes.GCM,
        ALGORITHMS.A256GCMKW: modes.GCM,
        ALGORITHMS.A128KW: None,
        ALGORITHMS.A192KW: None,
        ALGORITHMS.A256KW: None,
    }

    # IV length in bytes for each mode name.
    IV_BYTE_LENGTH_MODE_MAP = {"CBC": algorithms.AES.block_size // 8, "GCM": 96 // 8}

    def __init__(self, key, algorithm):
        """Validate ``algorithm`` and the length of ``key`` and store both.

        Raises:
            JWKError: if the algorithm is not a supported AES algorithm or the
                key length does not match it.
        """
        if algorithm not in ALGORITHMS.AES:
            raise JWKError("%s is not a valid AES algorithm" % algorithm)
        if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
            raise JWKError("%s is not a supported algorithm" % algorithm)

        self._algorithm = algorithm
        self._mode = self.MODES.get(self._algorithm)

        if algorithm in self.KEY_128 and len(key) != 16:
            raise JWKError(f"Key must be 128 bit for alg {algorithm}")
        elif algorithm in self.KEY_192 and len(key) != 24:
            raise JWKError(f"Key must be 192 bit for alg {algorithm}")
        elif algorithm in self.KEY_256 and len(key) != 32:
            raise JWKError(f"Key must be 256 bit for alg {algorithm}")
        elif algorithm in self.KEY_384 and len(key) != 48:
            raise JWKError(f"Key must be 384 bit for alg {algorithm}")
        elif algorithm in self.KEY_512 and len(key) != 64:
            raise JWKError(f"Key must be 512 bit for alg {algorithm}")

        self._key = key

    def to_dict(self):
        """Export the key as an ``oct`` JWK dictionary."""
        data = {"alg": self._algorithm, "kty": "oct", "k": base64url_encode(self._key)}
        return data

    def encrypt(self, plain_text, aad=None):
        """Encrypt ``plain_text`` with a freshly generated IV.

        Args:
            plain_text: data to encrypt (converted with ``ensure_binary``).
            aad: additional authenticated data; only used in GCM mode.

        Returns:
            ``(iv, cipher_text, auth_tag)``; ``auth_tag`` is the 16-byte GCM
            tag, or None in CBC mode.

        Raises:
            JWEError: wrapping any failure raised while encrypting.
        """
        plain_text = ensure_binary(plain_text)
        try:
            # The map holds byte lengths, so the fallback must be in bytes too
            # (AES.block_size is expressed in bits).
            iv_byte_length = self.IV_BYTE_LENGTH_MODE_MAP.get(self._mode.name, algorithms.AES.block_size // 8)
            iv = get_random_bytes(iv_byte_length)
            mode = self._mode(iv)
            if mode.name == "GCM":
                cipher = aead.AESGCM(self._key)
                cipher_text_and_tag = cipher.encrypt(iv, plain_text, aad)
                # AESGCM appends the 16-byte tag to the cipher text; split them.
                cipher_text = cipher_text_and_tag[: len(cipher_text_and_tag) - 16]
                auth_tag = cipher_text_and_tag[-16:]
            else:
                cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
                encryptor = cipher.encryptor()
                padder = PKCS7(algorithms.AES.block_size).padder()
                padded_data = padder.update(plain_text)
                padded_data += padder.finalize()
                cipher_text = encryptor.update(padded_data) + encryptor.finalize()
                auth_tag = None
            return iv, cipher_text, auth_tag
        except Exception as e:
            raise JWEError(e) from e

    def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
        """Decrypt ``cipher_text`` and return the plain text bytes.

        Args:
            cipher_text: data to decrypt (converted with ``ensure_binary``).
            iv: initialization vector used for encryption.
            aad: additional authenticated data; only used in GCM mode.
            tag: authentication tag; required in GCM mode.

        Raises:
            JWEError: ``"Invalid JWE Auth Tag"`` when GCM authentication fails,
                otherwise wrapping whatever failure occurred.
        """
        cipher_text = ensure_binary(cipher_text)
        try:
            iv = ensure_binary(iv)
            mode = self._mode(iv)
            if mode.name == "GCM":
                if tag is None:
                    raise ValueError("tag cannot be None")
                cipher = aead.AESGCM(self._key)
                cipher_text_and_tag = cipher_text + tag
                try:
                    plain_text = cipher.decrypt(iv, cipher_text_and_tag, aad)
                except InvalidTag as e:
                    raise JWEError("Invalid JWE Auth Tag") from e
            else:
                cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
                decryptor = cipher.decryptor()
                padded_plain_text = decryptor.update(cipher_text)
                padded_plain_text += decryptor.finalize()
                unpadder = PKCS7(algorithms.AES.block_size).unpadder()
                plain_text = unpadder.update(padded_plain_text)
                plain_text += unpadder.finalize()

            return plain_text
        except JWEError:
            # Already the right type; do not wrap it in a second JWEError.
            raise
        except Exception as e:
            raise JWEError(e) from e

    def wrap_key(self, key_data):
        """Wrap ``key_data`` with AES key wrap and return the wrapped key bytes."""
        key_data = ensure_binary(key_data)
        cipher_text = aes_key_wrap(self._key, key_data, default_backend())
        return cipher_text

    def unwrap_key(self, wrapped_key):
        """Unwrap ``wrapped_key`` with AES key wrap and return the key bytes.

        Raises:
            JWEError: if the wrapped key fails the unwrap integrity check.
        """
        wrapped_key = ensure_binary(wrapped_key)
        try:
            plain_text = aes_key_unwrap(self._key, wrapped_key, default_backend())
        except InvalidUnwrap as cause:
            raise JWEError(cause)
        return plain_text
517
518
class CryptographyHMACKey(Key):
    """
    HMAC-based signing key: signs and verifies messages with the hash
    function selected by the algorithm (HS256, HS384 or HS512).
    """

    ALG_MAP = {
        ALGORITHMS.HS256: hashes.SHA256(),
        ALGORITHMS.HS384: hashes.SHA384(),
        ALGORITHMS.HS512: hashes.SHA512(),
    }

    def __init__(self, key, algorithm):
        """Store the secret given as a JWK dict, ``str`` or ``bytes``.

        Raises JWKError for a non-HMAC algorithm, an unsupported key type, or
        a secret that looks like a PEM/SSH asymmetric key or certificate.
        """
        if algorithm not in ALGORITHMS.HMAC:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
        self._algorithm = algorithm
        self._hash_alg = self.ALG_MAP.get(algorithm)

        # JWK dictionaries are decoded separately.
        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if not isinstance(key, (str, bytes)):
            raise JWKError("Expecting a string- or bytes-formatted key.")

        secret = key.encode("utf-8") if isinstance(key, str) else key

        if is_pem_format(secret) or is_ssh_key(secret):
            raise JWKError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        self.prepared_key = secret

    def _process_jwk(self, jwk_dict):
        """Return the raw secret bytes decoded from the ``k`` member of an ``oct`` JWK."""
        key_type = jwk_dict.get("kty")
        if key_type != "oct":
            raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))

        encoded_secret = jwk_dict.get("k").encode("utf-8")
        return base64url_decode(encoded_secret)

    def to_dict(self):
        """Export the key as an ``oct`` JWK dictionary."""
        encoded_secret = base64url_encode(self.prepared_key).decode("ASCII")
        return {"alg": self._algorithm, "kty": "oct", "k": encoded_secret}

    def sign(self, msg):
        """Return the HMAC of ``msg`` computed with the stored secret."""
        mac = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
        mac.update(ensure_binary(msg))
        return mac.finalize()

    def verify(self, msg, sig):
        """Return True when ``sig`` is the HMAC of ``msg``, False otherwise."""
        payload = ensure_binary(msg)
        expected = ensure_binary(sig)
        mac = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
        mac.update(payload)
        try:
            mac.verify(expected)
        except InvalidSignature:
            return False
        return True