Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 17%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6import os
7from abc import ABC, abstractmethod
8from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
10from .exceptions import InvalidKeyError
11from .types import HashlibHash, JWKDict
12from .utils import (
13 base64url_decode,
14 base64url_encode,
15 der_to_raw_signature,
16 force_bytes,
17 from_base64url_uint,
18 is_pem_format,
19 is_ssh_key,
20 raw_to_der_signature,
21 to_base64url_uint,
22)
try:
    # ``cryptography`` is an optional dependency. These imports back the
    # asymmetric algorithms (RSA, RSA-PSS, ECDSA, EdDSA) that are defined
    # under ``if has_crypto:`` further down. If the package is missing, the
    # first import raises ModuleNotFoundError and ``has_crypto`` becomes False,
    # so get_default_algorithms() only registers "none" and the HMAC family.
    # Only ModuleNotFoundError is caught: any other ImportError (e.g. a name
    # missing from an incompatible cryptography release) propagates.
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    # pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
    # in Py >= 3.10, we can replace this with the Union types below
    # Each tuple is usable directly as the second argument of isinstance();
    # the algorithm classes assign one of them to ``_crypto_key_types``.
    ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
    ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
    ALLOWED_OKP_KEY_TYPES = (
        Ed25519PrivateKey,
        Ed25519PublicKey,
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    # Union of every key family above (RSA, then EC, then OKP).
    ALLOWED_KEY_TYPES = (
        ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
    )
    # The same key classes, split by private / public half.
    ALLOWED_PRIVATE_KEY_TYPES = (
        RSAPrivateKey,
        EllipticCurvePrivateKey,
        Ed25519PrivateKey,
        Ed448PrivateKey,
    )
    ALLOWED_PUBLIC_KEY_TYPES = (
        RSAPublicKey,
        EllipticCurvePublicKey,
        Ed25519PublicKey,
        Ed448PublicKey,
    )

    # The aliases below are only defined for static type checkers and for
    # Sphinx builds (SPHINX_BUILD set to a non-empty value). At normal runtime
    # they do not exist. This is safe because ``from __future__ import
    # annotations`` at the top of the module keeps every annotation an
    # unevaluated string.
    if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
        from typing import TypeAlias

        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

        # Type aliases for convenience in algorithms method signatures
        AllowedRSAKeys: TypeAlias = RSAPrivateKey | RSAPublicKey
        AllowedECKeys: TypeAlias = EllipticCurvePrivateKey | EllipticCurvePublicKey
        AllowedOKPKeys: TypeAlias = (
            Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
        )
        AllowedKeys: TypeAlias = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
        #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
        AllowedPrivateKeys: TypeAlias = (
            RSAPrivateKey
            | EllipticCurvePrivateKey
            | Ed25519PrivateKey
            | Ed448PrivateKey
        )
        #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
        AllowedPublicKeys: TypeAlias = (
            RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
        )

    # Only reached if every import above succeeded.
    has_crypto = True
except ModuleNotFoundError:
    has_crypto = False
#: ``alg`` names that can only be used when the optional ``cryptography``
#: package is installed (see ``has_crypto``).
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256", "RS384", "RS512",
    # ECDSA -- ES521 and ES512 both map to ECDSA with SHA-512; ES512 is the
    # backward-compatible spelling
    "ES256", "ES256K", "ES384", "ES521", "ES512",
    # RSASSA-PSS
    "PS256", "PS384", "PS512",
    # EdDSA (Ed25519 / Ed448)
    "EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Returns the algorithms that are implemented by the library.

    A new dict is built on every call, mapping each ``alg`` name to a fresh
    :class:`Algorithm` instance. ``"none"`` and HS256/HS384/HS512 are always
    present. The RSA, ECDSA, RSA-PSS and EdDSA entries are appended only when
    ``cryptography`` could be imported (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}
    algorithms.update(
        HS256=HMACAlgorithm(HMACAlgorithm.SHA256),
        HS384=HMACAlgorithm(HMACAlgorithm.SHA384),
        HS512=HMACAlgorithm(HMACAlgorithm.SHA512),
    )

    if not has_crypto:
        return algorithms

    # Keyword order is preserved, so the resulting dict order matches the
    # listing below.
    algorithms.update(
        RS256=RSAAlgorithm(RSAAlgorithm.SHA256),
        RS384=RSAAlgorithm(RSAAlgorithm.SHA384),
        RS512=RSAAlgorithm(RSAAlgorithm.SHA512),
        ES256=ECAlgorithm(ECAlgorithm.SHA256),
        ES256K=ECAlgorithm(ECAlgorithm.SHA256),
        ES384=ECAlgorithm(ECAlgorithm.SHA384),
        ES521=ECAlgorithm(ECAlgorithm.SHA512),
        ES512=ECAlgorithm(ECAlgorithm.SHA512),  # Backward compat for #219 fix
        PS256=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
        PS384=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
        PS512=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
        EdDSA=OKPAlgorithm(),
    )
    return algorithms
class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    # Subclasses backed by ``cryptography`` set this to the tuple of key
    # classes they accept; ``None`` means "not a cryptography-based algorithm".
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        ``self.hash_alg`` may be either a ``cryptography`` hash class (as used
        by the RSA/EC algorithms) or a ``hashlib`` constructor (as used by
        HMAC). Both forms are handled.

        If there is no hash algorithm, raises a NotImplementedError.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            # cryptography hash class: instantiate it and feed a Hash context.
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            # hashlib-style constructor: hash_alg(data).digest()
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Join the names eagerly: interpolating a generator expression into
            # the f-string would render as "<generator object ...>".
            valid_classes = ", ".join(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """
class NoneAlgorithm(Algorithm):
    """
    Implementation of ``alg: "none"``: the signature is always empty and
    verification always reports failure.

    Placeholder for use when no signing or verification operations are
    required.
    """

    def prepare_key(self, key: str | None) -> None:
        # Only "no key" is acceptable here; an empty string counts as no key.
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Nothing to sign with: always an empty signature.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Never accept a signature, regardless of inputs.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialize.
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # A hashlib constructor (e.g. hashlib.sha256), passed to hmac.new().
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return the HMAC secret as bytes.

        :raises InvalidKeyError: if the value looks like PEM or SSH key
            material. An asymmetric public key must never be accepted as a
            shared HMAC secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize a secret as an ``oct`` JWK whose ``k`` member is the
        base64url-encoded key bytes. Returns a dict if ``as_dict`` is true,
        otherwise a JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode an ``oct`` JWK (JSON string or dict) into the raw secret bytes.

        :raises InvalidKeyError: if the input is not a JSON object, ``kty`` is
            not ``"oct"``, or the ``k`` member is missing
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
            # Valid JSON that is not an object (e.g. a list) is not a JWK.
            if not isinstance(obj, dict):
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # compare_digest avoids leaking the mismatch position through timing.
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        # Key classes accepted by prepare_key() / check_crypto_key_type().
        _crypto_key_types = ALLOWED_RSA_KEY_TYPES

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            # A cryptography hash *class*; it is instantiated on each sign/verify.
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return ``key`` as an RSA key object.

            RSA key objects are returned unchanged. ``str``/``bytes`` input that
            starts with ``ssh-rsa`` is loaded as an OpenSSH public key. Anything
            else is loaded as a PEM private key with no password. If either load
            raises ``ValueError``, the input is retried as a PEM public key.

            :raises TypeError: if ``key`` is not a key object, ``str`` or ``bytes``
            :raises InvalidKeyError: if the PEM public-key fallback fails, or the
                loaded key is not an RSA key
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                else:
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        key_bytes, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
            except ValueError:
                try:
                    public_key = load_pem_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as an ``RSA`` JWK.

            Private keys export n, e, d, p, q, dp, dq and qi. Public keys export
            n and e. Returns a dict if ``as_dict`` is true, otherwise a JSON
            string.

            :raises InvalidKeyError: if ``key_obj`` is neither a private nor a
                public key
            """
            obj: dict[str, Any] | None = None

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key object from an ``RSA`` JWK (JSON string or dict).

            When ``d`` is present, a private key is built. If the CRT members
            (p, q, dp, dq, qi) are absent, the primes are recovered from n, e
            and d. Otherwise a public key is built from n and e.

            :raises InvalidKeyError: for non-object input, a wrong ``kty``,
                multi-prime (``oth``) keys, a partial set of CRT members, or
                missing n/e
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object (e.g. a list) is not a JWK.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    ) from None

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only n, e, d given: recover p and q, then derive CRT values.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            # Only InvalidSignature maps to False; other errors propagate.
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """

        # NOTE(review): nothing in this class checks that the key's curve
        # matches the hash (e.g. ES256 with a P-384 key). Confirm whether
        # callers enforce this.

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        # Key classes accepted by prepare_key() / check_crypto_key_type().
        _crypto_key_types = ALLOWED_EC_KEY_TYPES

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Return ``key`` as an EC key object.

            EC key objects are returned unchanged. ``str``/``bytes`` input is
            first loaded as a public key: OpenSSH if it starts with
            ``ecdsa-sha2-``, otherwise PEM. If that raises ``ValueError``, it is
            loaded as a PEM private key with no password.

            :raises TypeError: if ``key`` is not a key object, ``str`` or ``bytes``
            :raises InvalidKeyError: if the loaded key is not an EC key
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                else:
                    public_key = load_pem_public_key(key_bytes)

                # Explicit check the key to prevent confusing errors from cryptography
                self.check_crypto_key_type(public_key)
                return cast(EllipticCurvePublicKey, public_key)
            except ValueError:
                private_key = load_pem_private_key(key_bytes, password=None)
                self.check_crypto_key_type(private_key)
                return cast(EllipticCurvePrivateKey, private_key)

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            # cryptography produces a DER signature; convert it to the raw
            # form produced by der_to_raw_signature for the key's curve.
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Return True if ``sig`` verifies for ``msg`` under ``key``.

            A private key is accepted; its public half is used. A ``sig`` that
            raw_to_der_signature rejects with ``ValueError`` yields False.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key as an ``EC`` JWK (crv, x, y, plus d for private
            keys), with coordinates padded to the curve's key size.

            :raises InvalidKeyError: if ``key_obj`` is not an EC key or its
                curve is not P-256, P-384, P-521 or secp256k1
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Build an EC key object from an ``EC`` JWK (JSON string or dict).

            The x/y (and d, when present) lengths are checked against the curve
            named by ``crv``. A private key is returned when ``d`` is present,
            otherwise a public key.

            :raises InvalidKeyError: for non-object input, a wrong ``kty``,
                missing x/y, an unknown curve, or wrongly sized coordinates/d
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object (e.g. a list) is not a JWK.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key") from None

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key") from None

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            curve_obj: EllipticCurve

            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve P-256"
                    ) from None
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = SECP384R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 48 bytes for curve P-384"
                    ) from None
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = SECP521R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 66 bytes for curve P-521"
                    ) from None
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            # The private scalar must have the same encoded width as the coords.
            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            return EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1

        MGF1 uses the same hash as the signature, and the salt length equals
        that hash's digest size.
        """

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(
                msg,
                padding.PSS(
                    mgf=padding.MGF1(self.hash_alg()),
                    salt_length=self.hash_alg().digest_size,
                ),
                self.hash_alg(),
            )

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(
                    sig,
                    msg,
                    padding.PSS(
                        mgf=padding.MGF1(self.hash_alg()),
                        salt_length=self.hash_alg().digest_size,
                    ),
                    self.hash_alg(),
                )
                return True
            except InvalidSignature:
                return False

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA

        This class requires ``cryptography>=2.6`` to be installed.
        """

        # Key classes accepted by prepare_key() / check_crypto_key_type().
        _crypto_key_types = ALLOWED_OKP_KEY_TYPES

        def __init__(self, **kwargs: Any) -> None:
            # Keyword arguments are accepted and ignored; there is no hash to set.
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Return ``key`` as an Ed25519/Ed448 key object.

            Non-``str``/``bytes`` input must already be an OKP key object.
            Text is dispatched on its content: ``-----BEGIN PUBLIC`` goes to a
            PEM public key, ``-----BEGIN PRIVATE`` to a PEM private key with no
            password, and a leading ``ssh-`` to an OpenSSH public key.

            :raises InvalidKeyError: if none of those markers match, or the
                loaded key is not an Ed25519/Ed448 key
            """
            if not isinstance(key, (str, bytes)):
                self.check_crypto_key_type(key)
                return cast("AllowedOKPKeys", key)

            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            key_bytes = key.encode("utf-8") if isinstance(key, str) else key

            loaded_key: PublicKeyTypes | PrivateKeyTypes
            if "-----BEGIN PUBLIC" in key_str:
                loaded_key = load_pem_public_key(key_bytes)
            elif "-----BEGIN PRIVATE" in key_str:
                loaded_key = load_pem_private_key(key_bytes, password=None)
            elif key_str[0:4] == "ssh-":
                loaded_key = load_ssh_public_key(key_bytes)
            else:
                raise InvalidKeyError("Not a public or private key")

            # Explicit check the key to prevent confusing errors from cryptography
            self.check_crypto_key_type(loaded_key)
            return cast("AllowedOKPKeys", loaded_key)

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``

            :param str|bytes msg: Message to sign (``str`` is UTF-8 encoded first)
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            return key.sign(msg_bytes)

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

            :param str|bytes sig: EdDSA signature to check ``msg`` against
            :param str|bytes msg: Message to verify
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance (a private key's public
                half is used)
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

                public_key = (
                    key.public_key()
                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                    else key
                )
                public_key.verify(sig_bytes, msg_bytes)
                return True  # If no exception was raised, the signature is valid.
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``OKP`` JWK.

            Public keys export x. Private keys export x and d (raw bytes,
            base64url-encoded). Returns a dict if ``as_dict`` is true,
            otherwise a JSON string.

            :raises InvalidKeyError: if ``key`` is not an Ed25519/Ed448 key
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "d": base64url_encode(force_bytes(d)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Build an Ed25519/Ed448 key object from an ``OKP`` JWK.

            A private key is built from ``d`` when present; otherwise a public
            key is built from ``x``.

            :raises InvalidKeyError: for non-object input, a wrong ``kty`` or
                ``crv``, a missing ``x``, or key bytes rejected by cryptography
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object (e.g. a list) is not a JWK.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519" and curve != "Ed448":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    if curve == "Ed25519":
                        return Ed25519PublicKey.from_public_bytes(x)
                    return Ed448PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                if curve == "Ed25519":
                    return Ed25519PrivateKey.from_private_bytes(d)
                return Ed448PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err