Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6import os
7from abc import ABC, abstractmethod
8from typing import (
9 TYPE_CHECKING,
10 Any,
11 ClassVar,
12 Literal,
13 NoReturn,
14 Union,
15 cast,
16 overload,
17)
19from .exceptions import InvalidKeyError
20from .types import HashlibHash, JWKDict
21from .utils import (
22 base64url_decode,
23 base64url_encode,
24 der_to_raw_signature,
25 force_bytes,
26 from_base64url_uint,
27 is_pem_format,
28 is_ssh_key,
29 raw_to_der_signature,
30 to_base64url_uint,
31)
try:
    # Everything in this ``try`` depends on the optional ``cryptography``
    # package. If it is missing, the ModuleNotFoundError handler at the bottom
    # sets ``has_crypto = False`` and none of the names below are defined.
    # Note: only ModuleNotFoundError is caught; any other ImportError (e.g. an
    # installed ``cryptography`` that lacks one of these names) propagates.
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    # pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
    # in Py >= 3.10, we can replace this with the Union types below
    # Runtime tuples of key classes, usable directly as the second argument
    # to isinstance(); each algorithm family stores one as _crypto_key_types.
    ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
    ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
    ALLOWED_OKP_KEY_TYPES = (
        Ed25519PrivateKey,
        Ed25519PublicKey,
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    # Every supported key class, private and public, across all families.
    ALLOWED_KEY_TYPES = (
        ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
    )
    ALLOWED_PRIVATE_KEY_TYPES = (
        RSAPrivateKey,
        EllipticCurvePrivateKey,
        Ed25519PrivateKey,
        Ed448PrivateKey,
    )
    ALLOWED_PUBLIC_KEY_TYPES = (
        RSAPublicKey,
        EllipticCurvePublicKey,
        Ed25519PublicKey,
        Ed448PublicKey,
    )

    # The aliases below exist only for static type checkers and Sphinx doc
    # builds. At normal runtime these names are NOT defined; annotations that
    # mention them still work because the module starts with
    # ``from __future__ import annotations`` (annotations are not evaluated).
    if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
        import sys

        if sys.version_info >= (3, 10):
            from typing import TypeAlias
        else:
            # Python 3.9 and lower
            from typing_extensions import TypeAlias

        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

        # Type aliases for convenience in algorithms method signatures
        AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
        AllowedECKeys: TypeAlias = Union[
            EllipticCurvePrivateKey, EllipticCurvePublicKey
        ]
        AllowedOKPKeys: TypeAlias = Union[
            Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
        ]
        AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
        #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
        AllowedPrivateKeys: TypeAlias = Union[
            RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
        ]
        #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
        AllowedPublicKeys: TypeAlias = Union[
            RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
        ]

    # Set last, so it only becomes True once every import above has succeeded.
    has_crypto = True
except ModuleNotFoundError:
    has_crypto = False
# Names of the JWS algorithms whose implementations need the optional
# ``cryptography`` package. Built as a union of per-family sets; the result is
# a plain mutable ``set`` of algorithm-name strings.
requires_cryptography = (
    {f"RS{size}" for size in (256, 384, 512)}  # RSASSA-PKCS1-v1_5
    # ECDSA; "ES521" and "ES512" both name P-521 with SHA-512 in
    # get_default_algorithms().
    | {"ES256", "ES256K", "ES384", "ES521", "ES512"}
    | {f"PS{size}" for size in (256, 384, 512)}  # RSASSA-PSS
    | {"EdDSA"}  # Ed25519 / Ed448
)
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Returns the algorithms that are implemented by the library.

    A new dictionary holding new algorithm instances is built on every call.
    ``"none"`` and the HMAC algorithms are always present. The RSA, ECDSA,
    RSA-PSS and EdDSA algorithms are added only when ``cryptography`` could be
    imported (``has_crypto``).

    :return: Mapping of JWS ``alg`` name to its :class:`Algorithm` instance.
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}

    for name, hmac_hash in (
        ("HS256", HMACAlgorithm.SHA256),
        ("HS384", HMACAlgorithm.SHA384),
        ("HS512", HMACAlgorithm.SHA512),
    ):
        algorithms[name] = HMACAlgorithm(hmac_hash)

    # The remaining classes are only defined when cryptography is available.
    if not has_crypto:
        return algorithms

    for name, rsa_hash in (
        ("RS256", RSAAlgorithm.SHA256),
        ("RS384", RSAAlgorithm.SHA384),
        ("RS512", RSAAlgorithm.SHA512),
    ):
        algorithms[name] = RSAAlgorithm(rsa_hash)

    for name, ec_hash, curve in (
        ("ES256", ECAlgorithm.SHA256, SECP256R1),
        ("ES256K", ECAlgorithm.SHA256, SECP256K1),
        ("ES384", ECAlgorithm.SHA384, SECP384R1),
        ("ES521", ECAlgorithm.SHA512, SECP521R1),
        # Backward compat for #219 fix
        ("ES512", ECAlgorithm.SHA512, SECP521R1),
    ):
        algorithms[name] = ECAlgorithm(ec_hash, curve)

    for name, pss_hash in (
        ("PS256", RSAPSSAlgorithm.SHA256),
        ("PS384", RSAPSSAlgorithm.SHA384),
        ("PS512", RSAPSSAlgorithm.SHA512),
    ):
        algorithms[name] = RSAPSSAlgorithm(pss_hash)

    algorithms["EdDSA"] = OKPAlgorithm()
    return algorithms
class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Subclasses implement key preparation, signing, verification and JWK
    (de)serialization. Algorithms backed by ``cryptography`` also set
    ``_crypto_key_types`` so that :meth:`check_crypto_key_type` can reject keys
    from the wrong key family.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        :param bytestr: The data to hash.
        :return: The raw digest bytes.
        :raises NotImplementedError: if this algorithm has no ``hash_alg``.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            # ``cryptography`` hash class (used by the RSA/EC/PSS algorithms).
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())

        # Otherwise a hashlib-style constructor (used by the HMAC algorithms).
        return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Join the names eagerly: interpolating a bare generator expression
            # into the f-string would render as "<generator object ...>".
            valid_classes = ", ".join(
                cls.__name__ for cls in self._crypto_key_types
            )
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK.

        Returns a dict when ``as_dict`` is true, otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK (JSON string or dict) back into a
        key object.
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        The base implementation never warns; subclasses override it.
        """
        return None
class NoneAlgorithm(Algorithm):
    """
    Placeholder for use when no signing or verification
    operations are required.

    Signing yields an empty signature. Verification always reports failure,
    and JWK conversion is not supported.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string counts as "no key"; any other value is rejected.
        if key == "" or key is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsigned tokens carry an empty signature.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Never accept a signature for the "none" algorithm.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # A hashlib constructor such as hashlib.sha256.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert ``key`` to bytes and refuse asymmetric key material.

        :raises InvalidKeyError: if the key looks like a PEM or SSH key, which
            must not be used as an HMAC shared secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """Serialize a shared secret as an ``oct`` JWK (dict or JSON string)."""
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode an ``oct`` JWK (JSON string or dict) into the raw secret bytes.

        :raises InvalidKeyError: if the input is not valid JSON, is not an
            ``oct`` key, or has no ``k`` member.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        # Report a missing secret as an invalid key rather than leaking a bare
        # KeyError (mirrors the "x" check in OKPAlgorithm.from_jwk).
        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """Warn when the secret is shorter than the hash's digest size."""
        hasher = self.hash_alg()
        min_length = hasher.digest_size
        if len(key) < min_length:
            return (
                f"The HMAC key is {len(key)} bytes long, which is below "
                f"the minimum recommended length of {min_length} bytes for "
                f"{hasher.name.upper()}. "
                f"See RFC 7518 Section 3.2."
            )
        return None

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison against a freshly computed MAC.
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_RSA_KEY_TYPES
        # Modulus size (bits) below which check_key_length() warns.
        _MIN_KEY_SIZE: ClassVar[int] = 2048

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def check_key_length(self, key: AllowedRSAKeys) -> str | None:
            """Warn when the RSA modulus is smaller than ``_MIN_KEY_SIZE`` bits."""
            bits = key.key_size
            if bits >= self._MIN_KEY_SIZE:
                return None
            return (
                f"The RSA key is {bits} bits long, which is below "
                f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. "
                f"See NIST SP 800-131A."
            )

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return an RSA key object, loading it from text/bytes if needed.

            Accepts an RSA key object unchanged, an ``ssh-rsa`` public key, a
            PEM private key, or (as a fallback) a PEM public key.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if no loader accepts the key or it is not RSA.
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            raw = force_bytes(key)

            try:
                if raw.startswith(b"ssh-rsa"):
                    ssh_key: PublicKeyTypes = load_ssh_public_key(raw)
                    self.check_crypto_key_type(ssh_key)
                    return cast(RSAPublicKey, ssh_key)
                pem_private: PrivateKeyTypes = load_pem_private_key(
                    raw, password=None
                )
                self.check_crypto_key_type(pem_private)
                return cast(RSAPrivateKey, pem_private)
            except ValueError:
                # Not an SSH/PEM private key: retry as a PEM public key.
                try:
                    pem_public = load_pem_public_key(raw)
                    self.check_crypto_key_type(pem_public)
                    return cast(RSAPublicKey, pem_public)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK (dict when ``as_dict`` else JSON).

            Private keys emit n, e, d and the CRT parameters with
            ``key_ops=["sign"]``; public keys emit n and e with
            ``key_ops=["verify"]``.
            """
            if hasattr(key_obj, "private_numbers"):
                # Private key
                priv = key_obj.private_numbers()
                pub = priv.public_numbers
                members: dict[str, int] = {
                    "n": pub.n,
                    "e": pub.e,
                    "d": priv.d,
                    "p": priv.p,
                    "q": priv.q,
                    "dp": priv.dmp1,
                    "dq": priv.dmq1,
                    "qi": priv.iqmp,
                }
                operation = "sign"
            elif hasattr(key_obj, "verify"):
                # Public key
                pub = key_obj.public_numbers()
                members = {"n": pub.n, "e": pub.e}
                operation = "verify"
            else:
                raise InvalidKeyError("Not a public or private key")

            obj: dict[str, Any] = {"kty": "RSA", "key_ops": [operation]}
            for member, value in members.items():
                obj[member] = to_base64url_uint(value).decode()

            return obj if as_dict else json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key object from a JWK (JSON string or dict).

            A JWK with ``d`` yields a private key; when the CRT parameters
            (p, q, dp, dq, qi) are absent they are recomputed from n, e, d.
            A JWK with only n and e yields a public key.

            :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``,
                multi-prime keys (``oth``), partial CRT parameters, or missing
                n/e.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if not ("n" in obj and "e" in obj):
                raise InvalidKeyError("Not a public or private key")

            if "d" not in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()

            # Private key
            if "oth" in obj:
                raise InvalidKeyError(
                    "Unsupported RSA private key: > 2 primes not supported"
                )

            crt_flags = [name in obj for name in ("p", "q", "dp", "dq", "qi")]
            has_crt = any(crt_flags)
            if has_crt and not all(crt_flags):
                raise InvalidKeyError(
                    "RSA key must include all parameters if any are present besides d"
                ) from None

            public_numbers = RSAPublicNumbers(
                from_base64url_uint(obj["e"]),
                from_base64url_uint(obj["n"]),
            )
            d = from_base64url_uint(obj["d"])

            if has_crt:
                private_numbers = RSAPrivateNumbers(
                    d=d,
                    p=from_base64url_uint(obj["p"]),
                    q=from_base64url_uint(obj["q"]),
                    dmp1=from_base64url_uint(obj["dp"]),
                    dmq1=from_base64url_uint(obj["dq"]),
                    iqmp=from_base64url_uint(obj["qi"]),
                    public_numbers=public_numbers,
                )
            else:
                # Only n, e, d given: recover the primes, then derive CRT values.
                p, q = rsa_recover_prime_factors(
                    public_numbers.n, d, public_numbers.e
                )
                private_numbers = RSAPrivateNumbers(
                    d=d,
                    p=p,
                    q=q,
                    dmp1=rsa_crt_dmp1(d, p),
                    dmq1=rsa_crt_dmq1(d, q),
                    iqmp=rsa_crt_iqmp(p, q),
                    public_numbers=public_numbers,
                )

            return private_numbers.private_key()

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
            except InvalidSignature:
                return False
            return True
595 class ECAlgorithm(Algorithm):
596 """
597 Performs signing and verification operations using
598 ECDSA and the specified hash function
599 """
601 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
602 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
603 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
605 _crypto_key_types = ALLOWED_EC_KEY_TYPES
607 def __init__(
608 self,
609 hash_alg: type[hashes.HashAlgorithm],
610 expected_curve: type[EllipticCurve] | None = None,
611 ) -> None:
612 self.hash_alg = hash_alg
613 self.expected_curve = expected_curve
615 def _validate_curve(self, key: AllowedECKeys) -> None:
616 """Validate that the key's curve matches the expected curve."""
617 if self.expected_curve is None:
618 return
620 if not isinstance(key.curve, self.expected_curve):
621 raise InvalidKeyError(
622 f"The key's curve '{key.curve.name}' does not match the expected "
623 f"curve '{self.expected_curve.name}' for this algorithm"
624 )
626 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
627 if isinstance(key, self._crypto_key_types):
628 self._validate_curve(key)
629 return key
631 if not isinstance(key, (bytes, str)):
632 raise TypeError("Expecting a PEM-formatted key.")
634 key_bytes = force_bytes(key)
636 # Attempt to load key. We don't know if it's
637 # a Signing Key or a Verifying Key, so we try
638 # the Verifying Key first.
639 try:
640 if key_bytes.startswith(b"ecdsa-sha2-"):
641 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
642 else:
643 public_key = load_pem_public_key(key_bytes)
645 # Explicit check the key to prevent confusing errors from cryptography
646 self.check_crypto_key_type(public_key)
647 ec_public_key = cast(EllipticCurvePublicKey, public_key)
648 self._validate_curve(ec_public_key)
649 return ec_public_key
650 except ValueError:
651 private_key = load_pem_private_key(key_bytes, password=None)
652 self.check_crypto_key_type(private_key)
653 ec_private_key = cast(EllipticCurvePrivateKey, private_key)
654 self._validate_curve(ec_private_key)
655 return ec_private_key
657 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
658 der_sig = key.sign(msg, ECDSA(self.hash_alg()))
660 return der_to_raw_signature(der_sig, key.curve)
662 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
663 try:
664 der_sig = raw_to_der_signature(sig, key.curve)
665 except ValueError:
666 return False
668 try:
669 public_key = (
670 key.public_key()
671 if isinstance(key, EllipticCurvePrivateKey)
672 else key
673 )
674 public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
675 return True
676 except InvalidSignature:
677 return False
679 @overload
680 @staticmethod
681 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ...
683 @overload
684 @staticmethod
685 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ...
687 @staticmethod
688 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
689 if isinstance(key_obj, EllipticCurvePrivateKey):
690 public_numbers = key_obj.public_key().public_numbers()
691 elif isinstance(key_obj, EllipticCurvePublicKey):
692 public_numbers = key_obj.public_numbers()
693 else:
694 raise InvalidKeyError("Not a public or private key")
696 if isinstance(key_obj.curve, SECP256R1):
697 crv = "P-256"
698 elif isinstance(key_obj.curve, SECP384R1):
699 crv = "P-384"
700 elif isinstance(key_obj.curve, SECP521R1):
701 crv = "P-521"
702 elif isinstance(key_obj.curve, SECP256K1):
703 crv = "secp256k1"
704 else:
705 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
707 obj: dict[str, Any] = {
708 "kty": "EC",
709 "crv": crv,
710 "x": to_base64url_uint(
711 public_numbers.x,
712 bit_length=key_obj.curve.key_size,
713 ).decode(),
714 "y": to_base64url_uint(
715 public_numbers.y,
716 bit_length=key_obj.curve.key_size,
717 ).decode(),
718 }
720 if isinstance(key_obj, EllipticCurvePrivateKey):
721 obj["d"] = to_base64url_uint(
722 key_obj.private_numbers().private_value,
723 bit_length=key_obj.curve.key_size,
724 ).decode()
726 if as_dict:
727 return obj
728 else:
729 return json.dumps(obj)
731 @staticmethod
732 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
733 try:
734 if isinstance(jwk, str):
735 obj = json.loads(jwk)
736 elif isinstance(jwk, dict):
737 obj = jwk
738 else:
739 raise ValueError
740 except ValueError:
741 raise InvalidKeyError("Key is not valid JSON") from None
743 if obj.get("kty") != "EC":
744 raise InvalidKeyError("Not an Elliptic curve key") from None
746 if "x" not in obj or "y" not in obj:
747 raise InvalidKeyError("Not an Elliptic curve key") from None
749 x = base64url_decode(obj.get("x"))
750 y = base64url_decode(obj.get("y"))
752 curve = obj.get("crv")
753 curve_obj: EllipticCurve
755 if curve == "P-256":
756 if len(x) == len(y) == 32:
757 curve_obj = SECP256R1()
758 else:
759 raise InvalidKeyError(
760 "Coords should be 32 bytes for curve P-256"
761 ) from None
762 elif curve == "P-384":
763 if len(x) == len(y) == 48:
764 curve_obj = SECP384R1()
765 else:
766 raise InvalidKeyError(
767 "Coords should be 48 bytes for curve P-384"
768 ) from None
769 elif curve == "P-521":
770 if len(x) == len(y) == 66:
771 curve_obj = SECP521R1()
772 else:
773 raise InvalidKeyError(
774 "Coords should be 66 bytes for curve P-521"
775 ) from None
776 elif curve == "secp256k1":
777 if len(x) == len(y) == 32:
778 curve_obj = SECP256K1()
779 else:
780 raise InvalidKeyError(
781 "Coords should be 32 bytes for curve secp256k1"
782 )
783 else:
784 raise InvalidKeyError(f"Invalid curve: {curve}")
786 public_numbers = EllipticCurvePublicNumbers(
787 x=int.from_bytes(x, byteorder="big"),
788 y=int.from_bytes(y, byteorder="big"),
789 curve=curve_obj,
790 )
792 if "d" not in obj:
793 return public_numbers.public_key()
795 d = base64url_decode(obj.get("d"))
796 if len(d) != len(x):
797 raise InvalidKeyError(
798 "D should be {} bytes for curve {}", len(x), curve
799 )
801 return EllipticCurvePrivateNumbers(
802 int.from_bytes(d, byteorder="big"), public_numbers
803 ).private_key()
805 class RSAPSSAlgorithm(RSAAlgorithm):
806 """
807 Performs a signature using RSASSA-PSS with MGF1
808 """
810 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
811 signature: bytes = key.sign(
812 msg,
813 padding.PSS(
814 mgf=padding.MGF1(self.hash_alg()),
815 salt_length=self.hash_alg().digest_size,
816 ),
817 self.hash_alg(),
818 )
819 return signature
821 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
822 try:
823 key.verify(
824 sig,
825 msg,
826 padding.PSS(
827 mgf=padding.MGF1(self.hash_alg()),
828 salt_length=self.hash_alg().digest_size,
829 ),
830 self.hash_alg(),
831 )
832 return True
833 except InvalidSignature:
834 return False
836 class OKPAlgorithm(Algorithm):
837 """
838 Performs signing and verification operations using EdDSA
840 This class requires ``cryptography>=2.6`` to be installed.
841 """
843 _crypto_key_types = ALLOWED_OKP_KEY_TYPES
845 def __init__(self, **kwargs: Any) -> None:
846 pass
848 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
849 if not isinstance(key, (str, bytes)):
850 self.check_crypto_key_type(key)
851 return key
853 key_str = key.decode("utf-8") if isinstance(key, bytes) else key
854 key_bytes = key.encode("utf-8") if isinstance(key, str) else key
856 loaded_key: PublicKeyTypes | PrivateKeyTypes
857 if "-----BEGIN PUBLIC" in key_str:
858 loaded_key = load_pem_public_key(key_bytes)
859 elif "-----BEGIN PRIVATE" in key_str:
860 loaded_key = load_pem_private_key(key_bytes, password=None)
861 elif key_str[0:4] == "ssh-":
862 loaded_key = load_ssh_public_key(key_bytes)
863 else:
864 raise InvalidKeyError("Not a public or private key")
866 # Explicit check the key to prevent confusing errors from cryptography
867 self.check_crypto_key_type(loaded_key)
868 return cast("AllowedOKPKeys", loaded_key)
870 def sign(
871 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
872 ) -> bytes:
873 """
874 Sign a message ``msg`` using the EdDSA private key ``key``
875 :param str|bytes msg: Message to sign
876 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
877 or :class:`.Ed448PrivateKey` isinstance
878 :return bytes signature: The signature, as bytes
879 """
880 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
881 signature: bytes = key.sign(msg_bytes)
882 return signature
884 def verify(
885 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
886 ) -> bool:
887 """
888 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
890 :param str|bytes sig: EdDSA signature to check ``msg`` against
891 :param str|bytes msg: Message to sign
892 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
893 A private or public EdDSA key instance
894 :return bool verified: True if signature is valid, False if not.
895 """
896 try:
897 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
898 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
900 public_key = (
901 key.public_key()
902 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
903 else key
904 )
905 public_key.verify(sig_bytes, msg_bytes)
906 return True # If no exception was raised, the signature is valid.
907 except InvalidSignature:
908 return False
910 @overload
911 @staticmethod
912 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ...
914 @overload
915 @staticmethod
916 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ...
918 @staticmethod
919 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
920 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
921 x = key.public_bytes(
922 encoding=Encoding.Raw,
923 format=PublicFormat.Raw,
924 )
925 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
927 obj = {
928 "x": base64url_encode(force_bytes(x)).decode(),
929 "kty": "OKP",
930 "crv": crv,
931 }
933 if as_dict:
934 return obj
935 else:
936 return json.dumps(obj)
938 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
939 d = key.private_bytes(
940 encoding=Encoding.Raw,
941 format=PrivateFormat.Raw,
942 encryption_algorithm=NoEncryption(),
943 )
945 x = key.public_key().public_bytes(
946 encoding=Encoding.Raw,
947 format=PublicFormat.Raw,
948 )
950 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
951 obj = {
952 "x": base64url_encode(force_bytes(x)).decode(),
953 "d": base64url_encode(force_bytes(d)).decode(),
954 "kty": "OKP",
955 "crv": crv,
956 }
958 if as_dict:
959 return obj
960 else:
961 return json.dumps(obj)
963 raise InvalidKeyError("Not a public or private key")
965 @staticmethod
966 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
967 try:
968 if isinstance(jwk, str):
969 obj = json.loads(jwk)
970 elif isinstance(jwk, dict):
971 obj = jwk
972 else:
973 raise ValueError
974 except ValueError:
975 raise InvalidKeyError("Key is not valid JSON") from None
977 if obj.get("kty") != "OKP":
978 raise InvalidKeyError("Not an Octet Key Pair")
980 curve = obj.get("crv")
981 if curve != "Ed25519" and curve != "Ed448":
982 raise InvalidKeyError(f"Invalid curve: {curve}")
984 if "x" not in obj:
985 raise InvalidKeyError('OKP should have "x" parameter')
986 x = base64url_decode(obj.get("x"))
988 try:
989 if "d" not in obj:
990 if curve == "Ed25519":
991 return Ed25519PublicKey.from_public_bytes(x)
992 return Ed448PublicKey.from_public_bytes(x)
993 d = base64url_decode(obj.get("d"))
994 if curve == "Ed25519":
995 return Ed25519PrivateKey.from_private_bytes(d)
996 return Ed448PrivateKey.from_private_bytes(d)
997 except ValueError as err:
998 raise InvalidKeyError("Invalid key parameter") from err