Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jose/backends/cryptography_backend.py: 1%
367 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:16 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:16 +0000
1import math
2import warnings
4from cryptography.exceptions import InvalidSignature, InvalidTag
5from cryptography.hazmat.backends import default_backend
6from cryptography.hazmat.bindings.openssl.binding import Binding
7from cryptography.hazmat.primitives import hashes, hmac, serialization
8from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
9from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
10from cryptography.hazmat.primitives.ciphers import Cipher, aead, algorithms, modes
11from cryptography.hazmat.primitives.keywrap import InvalidUnwrap, aes_key_unwrap, aes_key_wrap
12from cryptography.hazmat.primitives.padding import PKCS7
13from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
14from cryptography.utils import int_to_bytes
15from cryptography.x509 import load_pem_x509_certificate
17from ..constants import ALGORITHMS
18from ..exceptions import JWEError, JWKError
19from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64
20from .base import Key
_binding = None


def get_random_bytes(num_bytes):
    """
    Get random bytes

    Currently, Cryptography returns OS random bytes. If you want OpenSSL
    generated random bytes, you'll have to switch the RAND engine after
    initializing the OpenSSL backend
    Args:
        num_bytes (int): Number of random bytes to generate and return
    Returns:
        bytes: Random bytes
    Raises:
        RuntimeError: If OpenSSL's RAND_bytes reports that it could not
            produce the requested bytes.
    """
    global _binding

    # The OpenSSL binding is created lazily on first use and then reused.
    if _binding is None:
        _binding = Binding()

    buf = _binding.ffi.new("char[]", num_bytes)
    # RAND_bytes returns 1 on success. ffi.new() zero-fills the buffer, so
    # ignoring a failure would silently hand back all-zero "random" bytes
    # to callers that use them as IVs and content-encryption keys.
    if _binding.lib.RAND_bytes(buf, num_bytes) != 1:
        raise RuntimeError("OpenSSL RAND_bytes failed to generate random bytes")
    return _binding.ffi.buffer(buf, num_bytes)[:]
class CryptographyECKey(Key):
    """ECDSA key (ES256 / ES384 / ES512) backed by the ``cryptography`` package.

    The key may be given as a ``cryptography`` key object, an object exposing
    ``to_pem()``, a JWK dict, or PEM-encoded ``str``/``bytes``.
    """

    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Prepare ``key`` for use with the EC ``algorithm``.

        Args:
            key: Key object, object with ``to_pem()``, JWK dict, or PEM str/bytes.
            algorithm: One of ``ALGORITHMS.EC``.
            cryptography_backend: Callable returning a backend; it is called
                each time a key is loaded.

        Raises:
            JWKError: If the algorithm is not an EC algorithm or the key
                cannot be parsed.
        """
        if algorithm not in ALGORITHMS.EC:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)

        self.hash_alg = {
            ALGORITHMS.ES256: self.SHA256,
            ALGORITHMS.ES384: self.SHA384,
            ALGORITHMS.ES512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        self.cryptography_backend = cryptography_backend

        # Already a cryptography key object (public or private): use it as-is.
        if hasattr(key, "public_bytes") or hasattr(key, "private_bytes"):
            self.prepared_key = key
            return

        if hasattr(key, "to_pem"):
            # convert to PEM and let cryptography below load it as PEM
            key = key.to_pem().decode("utf-8")

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            # We don't know whether this is a public or a private key, so try
            # the public form first and fall back to the private form.
            try:
                try:
                    key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                raise JWKError(e) from e

            self.prepared_key = key
            return

        raise JWKError("Unable to parse an ECKey from key: %s" % key)

    def _process_jwk(self, jwk_dict):
        """Build a cryptography EC key from a JWK dict.

        Returns a private key when ``"d"`` is present, otherwise a public key.

        Raises:
            JWKError: On a wrong ``kty``, missing ``x``/``y``/``crv``, or a
                curve other than P-256, P-384 or P-521.
        """
        if not jwk_dict.get("kty") == "EC":
            raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))

        if not all(k in jwk_dict for k in ["x", "y", "crv"]):
            raise JWKError("Mandatory parameters are missing")

        curves = {
            "P-256": ec.SECP256R1,
            "P-384": ec.SECP384R1,
            "P-521": ec.SECP521R1,
        }
        crv = jwk_dict["crv"]
        # Report an unknown curve as a JWK error rather than a bare KeyError.
        if not isinstance(crv, str) or crv not in curves:
            raise JWKError("Unsupported EC curve: %s" % crv)
        curve = curves[crv]

        x = base64_to_long(jwk_dict.get("x"))
        y = base64_to_long(jwk_dict.get("y"))
        public = ec.EllipticCurvePublicNumbers(x, y, curve())

        if "d" in jwk_dict:
            d = base64_to_long(jwk_dict.get("d"))
            private = ec.EllipticCurvePrivateNumbers(d, public)

            return private.private_key(self.cryptography_backend())
        else:
            return public.public_key(self.cryptography_backend())

    def _sig_component_length(self):
        """Determine the correct serialization length for an encoded signature component.

        This is the number of bytes required to encode the maximum key value,
        i.e. the key size in bits rounded up to whole bytes.
        """
        return (self.prepared_key.key_size + 7) // 8

    def _der_to_raw(self, der_signature):
        """Convert signature from DER encoding to RAW (r || s) encoding."""
        r, s = decode_dss_signature(der_signature)
        component_length = self._sig_component_length()
        return int_to_bytes(r, component_length) + int_to_bytes(s, component_length)

    def _raw_to_der(self, raw_signature):
        """Convert signature from RAW (r || s) encoding to DER encoding.

        Raises:
            ValueError: If the signature is not exactly two components long.
        """
        component_length = self._sig_component_length()
        if len(raw_signature) != 2 * component_length:
            raise ValueError("Invalid signature")

        r_bytes = raw_signature[:component_length]
        s_bytes = raw_signature[component_length:]
        r = int.from_bytes(r_bytes, "big")
        s = int.from_bytes(s_bytes, "big")
        return encode_dss_signature(r, s)

    def sign(self, msg):
        """Sign ``msg`` with ECDSA and return the raw ``r || s`` signature.

        Raises:
            TypeError: If the digest is longer than the curve's key size.
        """
        if self.hash_alg.digest_size * 8 > self.prepared_key.curve.key_size:
            raise TypeError(
                "this curve (%s) is too short "
                "for your digest (%d)" % (self.prepared_key.curve.name, 8 * self.hash_alg.digest_size)
            )
        signature = self.prepared_key.sign(msg, ec.ECDSA(self.hash_alg()))
        return self._der_to_raw(signature)

    def verify(self, msg, sig):
        """Return True if ``sig`` (raw ``r || s``) is a valid signature of ``msg``.

        Any failure (wrong length, invalid signature, unusable key) yields False.
        """
        try:
            signature = self._raw_to_der(sig)
            self.prepared_key.verify(signature, msg, ec.ECDSA(self.hash_alg()))
            return True
        except Exception:
            return False

    def is_public(self):
        """Return True when the prepared key is a public key."""
        return hasattr(self.prepared_key, "public_bytes")

    def public_key(self):
        """Return this key if public, otherwise a new key wrapping its public half."""
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.public_key(), self._algorithm)

    def to_pem(self):
        """Serialize as PEM: SubjectPublicKeyInfo if public, unencrypted traditional OpenSSL if private."""
        if self.is_public():
            pem = self.prepared_key.public_bytes(
                encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
            return pem
        pem = self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return pem

    def to_dict(self):
        """Return the key as a JWK dict (includes ``d`` for private keys).

        Raises:
            JWKError: If the key's curve has no JWK name in this backend.
        """
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        curve_names = {
            "secp256r1": "P-256",
            "secp384r1": "P-384",
            "secp521r1": "P-521",
        }
        curve_name = self.prepared_key.curve.name
        if curve_name not in curve_names:
            raise JWKError("Unsupported EC curve: %s" % curve_name)
        crv = curve_names[curve_name]

        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
        # points must be encoded as octet-strings of this length.
        key_size = (self.prepared_key.curve.key_size + 7) // 8

        public_numbers = public_key.public_numbers()
        data = {
            "alg": self._algorithm,
            "kty": "EC",
            "crv": crv,
            "x": long_to_base64(public_numbers.x, size=key_size).decode("ASCII"),
            "y": long_to_base64(public_numbers.y, size=key_size).decode("ASCII"),
        }

        if not self.is_public():
            private_value = self.prepared_key.private_numbers().private_value
            data["d"] = long_to_base64(private_value, size=key_size).decode("ASCII")

        return data
class CryptographyRSAKey(Key):
    """RSA key backed by the ``cryptography`` package.

    Used for RS256/RS384/RS512 signatures and for RSA1_5 / RSA-OAEP /
    RSA-OAEP-256 key wrapping, depending on ``algorithm``.
    """

    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    RSA1_5 = padding.PKCS1v15()
    RSA_OAEP = padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)
    RSA_OAEP_256 = padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Prepare ``key`` for use with the RSA ``algorithm``.

        Args:
            key: RSA public/private key object, JWK dict, or PEM str/bytes
                (public key, private key, or X.509 certificate).
            algorithm: One of ``ALGORITHMS.RSA``.
            cryptography_backend: Callable returning a backend; it is called
                each time a key is loaded.

        Raises:
            JWKError: If the algorithm is not an RSA algorithm or the key
                cannot be parsed.
        """
        if algorithm not in ALGORITHMS.RSA:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)

        # Each attribute is None when the algorithm is not in its mapping
        # (e.g. hash_alg for the key-wrapping algorithms).
        self.hash_alg = {
            ALGORITHMS.RS256: self.SHA256,
            ALGORITHMS.RS384: self.SHA384,
            ALGORITHMS.RS512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        self.padding = {
            ALGORITHMS.RSA1_5: self.RSA1_5,
            ALGORITHMS.RSA_OAEP: self.RSA_OAEP,
            ALGORITHMS.RSA_OAEP_256: self.RSA_OAEP_256,
        }.get(algorithm)

        self.cryptography_backend = cryptography_backend

        # if it conforms to RSAPublicKey interface
        if hasattr(key, "public_bytes") and hasattr(key, "public_numbers"):
            self.prepared_key = key
            return

        # if it conforms to RSAPrivateKey interface (PEM private keys are
        # already accepted below, so accept the equivalent key object too)
        if hasattr(key, "private_bytes") and hasattr(key, "private_numbers"):
            self.prepared_key = key
            return

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            try:
                if key.startswith(b"-----BEGIN CERTIFICATE-----"):
                    self._process_cert(key)
                    return

                # Try the public form first and fall back to the private form.
                try:
                    self.prepared_key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    self.prepared_key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                raise JWKError(e) from e
            return

        raise JWKError("Unable to parse an RSA_JWK from key: %s" % key)

    def _process_jwk(self, jwk_dict):
        """Build a cryptography RSA key from a JWK dict.

        Returns a private key when ``"d"`` is present, otherwise a public key.
        Missing CRT parameters of a private key are recovered from n, e and d.

        Raises:
            JWKError: On a wrong ``kty``, missing ``n``/``e``, or an
                incomplete set of precomputed private parameters.
        """
        if not jwk_dict.get("kty") == "RSA":
            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))

        # RFC 7518 section 6.3.1: "n" and "e" are required for every RSA JWK.
        if not all(k in jwk_dict for k in ("n", "e")):
            raise JWKError("Mandatory parameters are missing")

        e = base64_to_long(jwk_dict["e"])
        n = base64_to_long(jwk_dict["n"])
        public = rsa.RSAPublicNumbers(e, n)

        if "d" not in jwk_dict:
            return public.public_key(self.cryptography_backend())

        # This is a private key.
        d = base64_to_long(jwk_dict.get("d"))

        extra_params = ["p", "q", "dp", "dq", "qi"]

        if any(k in jwk_dict for k in extra_params):
            # Precomputed private key parameters are available.
            if not all(k in jwk_dict for k in extra_params):
                # These values must be present when 'p' is according to
                # Section 6.3.2 of RFC7518, so if they are not we raise
                # an error.
                raise JWKError("Precomputed private key parameters are incomplete.")

            p = base64_to_long(jwk_dict["p"])
            q = base64_to_long(jwk_dict["q"])
            dp = base64_to_long(jwk_dict["dp"])
            dq = base64_to_long(jwk_dict["dq"])
            qi = base64_to_long(jwk_dict["qi"])
        else:
            # The precomputed private key parameters are not available,
            # so we use cryptography's API to fill them in.
            p, q = rsa.rsa_recover_prime_factors(n, e, d)
            dp = rsa.rsa_crt_dmp1(d, p)
            dq = rsa.rsa_crt_dmq1(d, q)
            qi = rsa.rsa_crt_iqmp(p, q)

        private = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public)

        return private.private_key(self.cryptography_backend())

    def _process_cert(self, key):
        """Load a PEM X.509 certificate and keep its public key."""
        key = load_pem_x509_certificate(key, self.cryptography_backend())
        self.prepared_key = key.public_key()

    def sign(self, msg):
        """Sign ``msg`` with RSASSA-PKCS1-v1_5; any failure is raised as JWKError."""
        try:
            signature = self.prepared_key.sign(msg, padding.PKCS1v15(), self.hash_alg())
        except Exception as e:
            raise JWKError(e) from e
        return signature

    def verify(self, msg, sig):
        """Return True if ``sig`` is a valid PKCS1-v1_5 signature of ``msg``.

        Only InvalidSignature is mapped to False; other errors propagate.
        Verifying with a private key works (its public half is used) but warns.
        """
        if not self.is_public():
            warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.")

        try:
            self.public_key().prepared_key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
            return True
        except InvalidSignature:
            return False

    def is_public(self):
        """Return True when the prepared key is a public key."""
        return hasattr(self.prepared_key, "public_bytes")

    def public_key(self):
        """Return this key if public, otherwise a new key wrapping its public half."""
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.public_key(), self._algorithm)

    def to_pem(self, pem_format="PKCS8"):
        """Serialize as unencrypted PEM.

        Args:
            pem_format: "PKCS8" (SubjectPublicKeyInfo / PKCS8) or "PKCS1"
                (PKCS1 / traditional OpenSSL).

        Raises:
            ValueError: For any other ``pem_format``.
        """
        if self.is_public():
            if pem_format == "PKCS8":
                fmt = serialization.PublicFormat.SubjectPublicKeyInfo
            elif pem_format == "PKCS1":
                fmt = serialization.PublicFormat.PKCS1
            else:
                raise ValueError("Invalid format specified: %r" % pem_format)
            pem = self.prepared_key.public_bytes(encoding=serialization.Encoding.PEM, format=fmt)
            return pem

        if pem_format == "PKCS8":
            fmt = serialization.PrivateFormat.PKCS8
        elif pem_format == "PKCS1":
            fmt = serialization.PrivateFormat.TraditionalOpenSSL
        else:
            raise ValueError("Invalid format specified: %r" % pem_format)

        return self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM, format=fmt, encryption_algorithm=serialization.NoEncryption()
        )

    def to_dict(self):
        """Return the key as a JWK dict (includes the CRT parameters for private keys)."""
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        public_numbers = public_key.public_numbers()
        data = {
            "alg": self._algorithm,
            "kty": "RSA",
            "n": long_to_base64(public_numbers.n).decode("ASCII"),
            "e": long_to_base64(public_numbers.e).decode("ASCII"),
        }

        if not self.is_public():
            # Fetch the private numbers once instead of once per field.
            private_numbers = self.prepared_key.private_numbers()
            data.update(
                {
                    "d": long_to_base64(private_numbers.d).decode("ASCII"),
                    "p": long_to_base64(private_numbers.p).decode("ASCII"),
                    "q": long_to_base64(private_numbers.q).decode("ASCII"),
                    "dp": long_to_base64(private_numbers.dmp1).decode("ASCII"),
                    "dq": long_to_base64(private_numbers.dmq1).decode("ASCII"),
                    "qi": long_to_base64(private_numbers.iqmp).decode("ASCII"),
                }
            )

        return data

    def wrap_key(self, key_data):
        """Encrypt ``key_data`` with the algorithm's padding; errors are raised as JWEError."""
        try:
            wrapped_key = self.prepared_key.encrypt(key_data, self.padding)
        except Exception as e:
            raise JWEError(e) from e

        return wrapped_key

    def unwrap_key(self, wrapped_key):
        """Decrypt ``wrapped_key`` with the algorithm's padding; errors are raised as JWEError."""
        try:
            unwrapped_key = self.prepared_key.decrypt(wrapped_key, self.padding)
            return unwrapped_key
        except Exception as e:
            raise JWEError(e) from e
class CryptographyAESKey(Key):
    """AES key for content encryption (GCM / CBC) and key wrapping (GCMKW / KW)."""

    KEY_128 = (ALGORITHMS.A128GCM, ALGORITHMS.A128GCMKW, ALGORITHMS.A128KW, ALGORITHMS.A128CBC)
    KEY_192 = (ALGORITHMS.A192GCM, ALGORITHMS.A192GCMKW, ALGORITHMS.A192KW, ALGORITHMS.A192CBC)
    KEY_256 = (
        ALGORITHMS.A256GCM,
        ALGORITHMS.A256GCMKW,
        ALGORITHMS.A256KW,
        ALGORITHMS.A128CBC_HS256,
        ALGORITHMS.A256CBC,
    )
    KEY_384 = (ALGORITHMS.A192CBC_HS384,)
    KEY_512 = (ALGORITHMS.A256CBC_HS512,)

    AES_KW_ALGS = (ALGORITHMS.A128KW, ALGORITHMS.A192KW, ALGORITHMS.A256KW)

    # Cipher mode per algorithm; None for AES key-wrap algorithms, which are
    # only usable through wrap_key / unwrap_key.
    MODES = {
        ALGORITHMS.A128GCM: modes.GCM,
        ALGORITHMS.A192GCM: modes.GCM,
        ALGORITHMS.A256GCM: modes.GCM,
        ALGORITHMS.A128CBC_HS256: modes.CBC,
        ALGORITHMS.A192CBC_HS384: modes.CBC,
        ALGORITHMS.A256CBC_HS512: modes.CBC,
        ALGORITHMS.A128CBC: modes.CBC,
        ALGORITHMS.A192CBC: modes.CBC,
        ALGORITHMS.A256CBC: modes.CBC,
        ALGORITHMS.A128GCMKW: modes.GCM,
        ALGORITHMS.A192GCMKW: modes.GCM,
        ALGORITHMS.A256GCMKW: modes.GCM,
        ALGORITHMS.A128KW: None,
        ALGORITHMS.A192KW: None,
        ALGORITHMS.A256KW: None,
    }

    # RFC 7518 sections 4.7.1.1 and 5.3 require a 96-bit IV for AES-GCM.
    _GCM_IV_SIZE = 12
    # Length in bytes of the authentication tag AESGCM appends to its output.
    _GCM_TAG_SIZE = 16

    def __init__(self, key, algorithm):
        """Validate ``algorithm`` and the key length it requires.

        Args:
            key: Raw key bytes (a ``str`` is encoded as UTF-8).
            algorithm: One of ``ALGORITHMS.AES``.

        Raises:
            JWKError: If the algorithm is unknown/unsupported or the key
                length does not match it.
        """
        if algorithm not in ALGORITHMS.AES:
            raise JWKError("%s is not a valid AES algorithm" % algorithm)
        if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
            raise JWKError("%s is not a supported algorithm" % algorithm)

        self._algorithm = algorithm
        self._mode = self.MODES.get(self._algorithm)

        # A str key previously passed the length check but then failed in
        # every cipher call; encode it the same way the HMAC key does.
        if isinstance(key, str):
            key = key.encode("utf-8")

        if algorithm in self.KEY_128 and len(key) != 16:
            raise JWKError(f"Key must be 128 bit for alg {algorithm}")
        elif algorithm in self.KEY_192 and len(key) != 24:
            raise JWKError(f"Key must be 192 bit for alg {algorithm}")
        elif algorithm in self.KEY_256 and len(key) != 32:
            raise JWKError(f"Key must be 256 bit for alg {algorithm}")
        elif algorithm in self.KEY_384 and len(key) != 48:
            raise JWKError(f"Key must be 384 bit for alg {algorithm}")
        elif algorithm in self.KEY_512 and len(key) != 64:
            raise JWKError(f"Key must be 512 bit for alg {algorithm}")

        self._key = key

    def to_dict(self):
        """Return the key as an ``oct`` JWK dict.

        NOTE(review): ``k`` is the raw ``base64url_encode`` result (bytes), unlike
        the HMAC key's ``to_dict`` which decodes it to str; kept as-is for
        compatibility with existing callers.
        """
        data = {"alg": self._algorithm, "kty": "oct", "k": base64url_encode(self._key)}
        return data

    def encrypt(self, plain_text, aad=None):
        """Encrypt ``plain_text`` with a fresh random IV.

        Returns:
            tuple: ``(iv, cipher_text, auth_tag)``; ``auth_tag`` is None for CBC.

        Raises:
            JWEError: On any encryption failure (including algorithms without a mode).
        """
        plain_text = ensure_binary(plain_text)
        try:
            if self._mode is modes.GCM:
                iv = get_random_bytes(self._GCM_IV_SIZE)
                cipher = aead.AESGCM(self._key)
                cipher_text_and_tag = cipher.encrypt(iv, plain_text, aad)
                cipher_text = cipher_text_and_tag[: -self._GCM_TAG_SIZE]
                auth_tag = cipher_text_and_tag[-self._GCM_TAG_SIZE :]
            else:
                iv = get_random_bytes(algorithms.AES.block_size // 8)
                # For key-wrap algorithms self._mode is None, so this raises
                # TypeError, which is reported as JWEError below.
                mode = self._mode(iv)
                cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
                encryptor = cipher.encryptor()
                padder = PKCS7(algorithms.AES.block_size).padder()
                padded_data = padder.update(plain_text)
                padded_data += padder.finalize()
                cipher_text = encryptor.update(padded_data) + encryptor.finalize()
                auth_tag = None
            return iv, cipher_text, auth_tag
        except Exception as e:
            raise JWEError(e) from e

    def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
        """Decrypt ``cipher_text``; GCM requires ``tag``, CBC output is PKCS7-unpadded.

        Raises:
            JWEError: On any decryption failure, including an invalid GCM tag.
        """
        cipher_text = ensure_binary(cipher_text)
        try:
            iv = ensure_binary(iv)
            # Constructing the mode also validates the IV length for it.
            mode = self._mode(iv)
            if mode.name == "GCM":
                if tag is None:
                    raise ValueError("tag cannot be None")
                cipher = aead.AESGCM(self._key)
                cipher_text_and_tag = cipher_text + tag
                try:
                    plain_text = cipher.decrypt(iv, cipher_text_and_tag, aad)
                except InvalidTag:
                    raise JWEError("Invalid JWE Auth Tag") from None
            else:
                cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
                decryptor = cipher.decryptor()
                padded_plain_text = decryptor.update(cipher_text)
                padded_plain_text += decryptor.finalize()
                unpadder = PKCS7(algorithms.AES.block_size).unpadder()
                plain_text = unpadder.update(padded_plain_text)
                plain_text += unpadder.finalize()

            return plain_text
        except JWEError:
            # Already the right type; don't wrap it a second time.
            raise
        except Exception as e:
            raise JWEError(e) from e

    def wrap_key(self, key_data):
        """Wrap ``key_data`` with AES key wrap (RFC 3394) and return the wrapped bytes.

        Raises:
            JWEError: If the key data cannot be wrapped (e.g. invalid length).
        """
        key_data = ensure_binary(key_data)
        try:
            return aes_key_wrap(self._key, key_data, default_backend())
        except ValueError as e:
            raise JWEError(e) from e

    def unwrap_key(self, wrapped_key):
        """Unwrap an AES-key-wrapped blob and return the plain key bytes.

        Raises:
            JWEError: If the integrity check of the unwrap fails.
        """
        wrapped_key = ensure_binary(wrapped_key)
        try:
            plain_text = aes_key_unwrap(self._key, wrapped_key, default_backend())
        except InvalidUnwrap as cause:
            raise JWEError(cause) from cause
        return plain_text
class CryptographyHMACKey(Key):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    ALG_MAP = {ALGORITHMS.HS256: hashes.SHA256(), ALGORITHMS.HS384: hashes.SHA384(), ALGORITHMS.HS512: hashes.SHA512()}

    # Substrings that mark public-key material (PEM, OpenSSH, RFC 4716 SSH2
    # formats). Using such a key as an HMAC secret enables algorithm-confusion
    # attacks (cf. CVE-2022-29217, CVE-2024-33663), so it is rejected. The
    # OpenSSH names also match their "-cert-v01@openssh.com" and "sk-" variants.
    _ASYMMETRIC_KEY_MARKERS = (
        b"-----BEGIN PUBLIC KEY-----",
        b"-----BEGIN RSA PUBLIC KEY-----",
        b"-----BEGIN CERTIFICATE-----",
        b"-----BEGIN TRUSTED CERTIFICATE-----",
        b"---- BEGIN SSH2 PUBLIC KEY ----",
        b"ssh-rsa",
        b"ssh-dss",
        b"ssh-ed25519",
        b"ecdsa-sha2-nistp256",
        b"ecdsa-sha2-nistp384",
        b"ecdsa-sha2-nistp521",
    )

    def __init__(self, key, algorithm):
        """Prepare ``key`` as the HMAC secret for ``algorithm``.

        Args:
            key: Secret as str (encoded UTF-8), bytes, or an ``oct`` JWK dict.
            algorithm: One of ``ALGORITHMS.HMAC``.

        Raises:
            JWKError: For an unknown algorithm, a key of the wrong type, or a
                key that looks like asymmetric key material.
        """
        if algorithm not in ALGORITHMS.HMAC:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
        self._algorithm = algorithm
        self._hash_alg = self.ALG_MAP.get(algorithm)

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if not isinstance(key, str) and not isinstance(key, bytes):
            raise JWKError("Expecting a string- or bytes-formatted key.")

        if isinstance(key, str):
            key = key.encode("utf-8")

        if any(marker in key for marker in self._ASYMMETRIC_KEY_MARKERS):
            raise JWKError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        self.prepared_key = key

    def _process_jwk(self, jwk_dict):
        """Return the decoded secret from an ``oct`` JWK dict.

        Raises:
            JWKError: On a wrong ``kty`` or a missing ``k``.
        """
        if not jwk_dict.get("kty") == "oct":
            raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))

        k = jwk_dict.get("k")
        if k is None:
            raise JWKError("Mandatory parameters are missing")
        if isinstance(k, str):
            k = k.encode("utf-8")
        return base64url_decode(k)

    def to_dict(self):
        """Return the key as an ``oct`` JWK dict with ``k`` as an ASCII str."""
        return {
            "alg": self._algorithm,
            "kty": "oct",
            "k": base64url_encode(self.prepared_key).decode("ASCII"),
        }

    def sign(self, msg):
        """Return the HMAC of ``msg`` (str is encoded first)."""
        msg = ensure_binary(msg)
        h = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
        h.update(msg)
        signature = h.finalize()
        return signature

    def verify(self, msg, sig):
        """Return True if ``sig`` is the HMAC of ``msg``, compared via ``HMAC.verify``."""
        msg = ensure_binary(msg)
        sig = ensure_binary(sig)
        h = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
        h.update(msg)
        try:
            h.verify(sig)
            verified = True
        except InvalidSignature:
            verified = False
        return verified