Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 19%
Shortcuts on this page:
- r, m, x: toggle line displays
- j, k: next/previous highlighted chunk
- 0 (zero): top of page
- 1 (one): first highlighted chunk
Shortcuts on this page:
- r, m, x: toggle line displays
- j, k: next/previous highlighted chunk
- 0 (zero): top of page
- 1 (one): first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6import os
7import sys
8from abc import ABC, abstractmethod
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 ClassVar,
13 Literal,
14 NoReturn,
15 Union,
16 cast,
17 get_args,
18 overload,
19)
21from .exceptions import InvalidKeyError
22from .types import HashlibHash, JWKDict
23from .utils import (
24 base64url_decode,
25 base64url_encode,
26 der_to_raw_signature,
27 force_bytes,
28 from_base64url_uint,
29 is_pem_format,
30 is_ssh_key,
31 raw_to_der_signature,
32 to_base64url_uint,
33)
# ``cryptography`` is an optional dependency. Everything it provides (key
# classes, hashes, padding, serialization helpers) is imported here; if the
# import fails, ``has_crypto`` is False and the asymmetric algorithm classes
# further down are not defined.
try:
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        # Python 3.9 and lower
        # NOTE(review): this import sits inside the ``try`` guarded by
        # ``except ModuleNotFoundError``, so a missing ``typing_extensions``
        # on 3.9 would also flip ``has_crypto`` to False — confirm intended.
        from typing_extensions import TypeAlias

    # Type aliases for convenience in algorithms method signatures
    AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
    AllowedECKeys: TypeAlias = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]
    AllowedOKPKeys: TypeAlias = Union[
        Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
    ]
    # Union of every key family handled by the algorithms in this module.
    AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
    #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
    AllowedPrivateKeys: TypeAlias = Union[
        RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
    ]
    #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
    AllowedPublicKeys: TypeAlias = Union[
        RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
    ]

    # These unions appear only in annotations (lazy strings thanks to
    # ``from __future__ import annotations``), so they are imported only for
    # type checkers and Sphinx documentation builds.
    if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

    has_crypto = True
except ModuleNotFoundError:
    # Without ``cryptography`` the alias names must still exist so that
    # annotations and ``from jwt.algorithms import ...`` keep working; they
    # are bound to ``Never`` so no value satisfies them for a type checker.
    if sys.version_info >= (3, 11):
        from typing import Never
    else:
        from typing_extensions import Never

    AllowedRSAKeys = Never  # type: ignore[misc]
    AllowedECKeys = Never  # type: ignore[misc]
    AllowedOKPKeys = Never  # type: ignore[misc]
    AllowedKeys = Never  # type: ignore[misc]
    AllowedPrivateKeys = Never  # type: ignore[misc]
    AllowedPublicKeys = Never  # type: ignore[misc]
    has_crypto = False
#: ``alg`` names whose implementations need the ``cryptography`` package.
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256", "RS384", "RS512",
    # ECDSA ("ES512" is kept alongside "ES521" for backward compatibility)
    "ES256", "ES256K", "ES384", "ES521", "ES512",
    # RSASSA-PSS
    "PS256", "PS384", "PS512",
    # EdDSA
    "EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Returns the algorithms that are implemented by the library.

    ``none`` and the HMAC family (``HS256``/``HS384``/``HS512``) are always
    present. The RSA, ECDSA, RSA-PSS and EdDSA entries are added only when
    ``cryptography`` could be imported (``has_crypto``).

    :return: A fresh dict mapping each JWA ``alg`` name to a new algorithm
        instance.
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}
    for alg_name, digest in (
        ("HS256", HMACAlgorithm.SHA256),
        ("HS384", HMACAlgorithm.SHA384),
        ("HS512", HMACAlgorithm.SHA512),
    ):
        algorithms[alg_name] = HMACAlgorithm(digest)

    if not has_crypto:
        return algorithms

    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)

    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    # Backward compat for #219 fix: "ES512" is configured exactly like "ES521".
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)

    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)

    algorithms["EdDSA"] = OKPAlgorithm()

    return algorithms
class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Subclasses implement key preparation, signing, verification and JWK
    (de)serialization for one family of JWA algorithms.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    # ``cryptography``-backed subclasses override this with the key classes
    # they accept; ``None`` marks algorithms that do not use such keys.
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        If there is no hash algorithm, raises a NotImplementedError.

        :param bytestr: The data to hash.
        :return: The raw digest bytes.
        :raises NotImplementedError: if the instance has no ``hash_alg``.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            # ``cryptography`` hash classes (used by the RSA/EC/PSS algorithms).
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())

        # hashlib-style constructors (used by the HMAC algorithms).
        return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Build a real list: interpolating a generator expression would
            # render as "<generator object ...>" instead of the class names.
            valid_classes = [cls.__name__ for cls in self._crypto_key_types]
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK

        :param key_obj: The key to serialize.
        :param as_dict: Return a dict when True, otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object

        :param jwk: The JWK as a JSON string or an already-parsed dict.
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        The base implementation performs no check and always returns None.
        """
        return None
class NoneAlgorithm(Algorithm):
    """
    Placeholder for use when no signing or verification
    operations are required.

    It produces empty signatures and never reports a signature as valid.
    """

    def prepare_key(self, key: str | None) -> None:
        """
        Accept only an absent key.

        :param key: Must be ``None``; the empty string is treated as ``None``.
        :raises InvalidKeyError: for any other value.
        """
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        """Return an empty signature regardless of ``msg``."""
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        """Always reject: an unsigned token is never considered verified."""
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        """Unsupported: there is no key to serialize."""
        raise NotImplementedError

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        """Unsupported: there is no key to deserialize."""
        raise NotImplementedError
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        """
        :param hash_alg: hashlib constructor used for the HMAC, e.g.
            :attr:`SHA256`.
        """
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes for use as the HMAC secret.

        :raises InvalidKeyError: if the key is PEM-formatted or an SSH key;
            asymmetric key material must not be reused as a shared secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize the secret as an ``oct`` JWK.

        :param key_obj: The HMAC secret.
        :param as_dict: Return a dict when True, otherwise a JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode the secret from an ``oct`` JWK.

        :param jwk: The JWK as a JSON string or an already-parsed dict.
        :return: The raw secret bytes from the ``"k"`` member.
        :raises InvalidKeyError: if ``jwk`` is not a JSON object, its ``kty``
            is not ``"oct"``, or it has no ``"k"`` member.
        """
        try:
            obj: Any = json.loads(jwk) if isinstance(jwk, str) else jwk
            # Valid JSON that is not an object (e.g. "[]") would otherwise
            # fail below with an AttributeError on ``obj.get``.
            if not isinstance(obj, dict):
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """
        Warn when the secret is shorter than the hash's digest size.

        :return: A warning message, or None if the key is long enough.
        """
        min_length = self.hash_alg().digest_size
        if len(key) < min_length:
            return (
                f"The HMAC key is {len(key)} bytes long, which is below "
                f"the minimum recommended length of {min_length} bytes for "
                f"{self.hash_alg().name.upper()}. "
                f"See RFC 7518 Section 3.2."
            )
        return None

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` under ``key``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """Compare ``sig`` with a recomputed HMAC using ``hmac.compare_digest``."""
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = cast(
            tuple[type[AllowedKeys], ...],
            get_args(Union[RSAPrivateKey, RSAPublicKey]),
        )
        # Key size (bits) below which check_key_length returns a warning.
        _MIN_KEY_SIZE: ClassVar[int] = 2048

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            """
            :param hash_alg: ``cryptography`` hash class, e.g. :attr:`SHA256`.
            """
            self.hash_alg = hash_alg

        def check_key_length(self, key: AllowedRSAKeys) -> str | None:
            """
            Return a warning if ``key.key_size`` is below ``_MIN_KEY_SIZE``
            bits, otherwise None.
            """
            if key.key_size < self._MIN_KEY_SIZE:
                return (
                    f"The RSA key is {key.key_size} bits long, which is below "
                    f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. "
                    f"See NIST SP 800-131A."
                )
            return None

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Load ``key`` as an RSA key.

            RSA key objects are returned unchanged. ``ssh-rsa`` strings are
            loaded as OpenSSH public keys. Other strings are tried as a PEM
            private key and then, on ValueError, as a PEM public key.

            :raises TypeError: if ``key`` is neither an RSA key nor str/bytes.
            :raises InvalidKeyError: if the fallback public-key parse fails, or
                the loaded key is not an RSA key.
            """
            if isinstance(key, self._crypto_key_types):
                return cast(AllowedRSAKeys, key)

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                else:
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        key_bytes, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
            except ValueError:
                # Not loadable as SSH / PEM private key: try PEM public key.
                try:
                    public_key = load_pem_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as an ``RSA`` JWK.

            Objects exposing ``private_numbers`` are serialized with the full
            private parameters; objects exposing ``verify`` with ``n``/``e``.

            :raises InvalidKeyError: if ``key_obj`` is neither.
            """
            obj: dict[str, Any] | None = None

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key from an ``RSA`` JWK.

            A JWK with ``d``, ``e`` and ``n`` yields a private key; if the CRT
            parameters (``p``, ``q``, ``dp``, ``dq``, ``qi``) are absent they
            are recomputed. A JWK with only ``n`` and ``e`` yields a public key.

            :raises InvalidKeyError: if the input is not a JSON object, is not
                an ``RSA`` key, uses more than two primes (``oth``), has only
                some CRT parameters, or lacks ``n``/``e``.
            """
            try:
                obj: Any = json.loads(jwk) if isinstance(jwk, str) else jwk
                # Valid JSON that is not an object (e.g. "[]") would otherwise
                # fail below with an AttributeError on ``obj.get``.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key")

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    )

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d was supplied: recover p and q, then derive the
                    # CRT exponents and coefficient from them.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Return the RSASSA-PKCS1-v1_5 signature of ``msg``."""
            signature: bytes = key.sign(msg, padding.PKCS1v15(), self.hash_alg())
            return signature

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` is a valid PKCS1-v1_5 signature of ``msg``."""
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False
581 class ECAlgorithm(Algorithm):
582 """
583 Performs signing and verification operations using
584 ECDSA and the specified hash function
585 """
587 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
588 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
589 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
591 _crypto_key_types = cast(
592 tuple[type[AllowedKeys], ...],
593 get_args(Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]),
594 )
596 def __init__(
597 self,
598 hash_alg: type[hashes.HashAlgorithm],
599 expected_curve: type[EllipticCurve] | None = None,
600 ) -> None:
601 self.hash_alg = hash_alg
602 self.expected_curve = expected_curve
604 def _validate_curve(self, key: AllowedECKeys) -> None:
605 """Validate that the key's curve matches the expected curve."""
606 if self.expected_curve is None:
607 return
609 if not isinstance(key.curve, self.expected_curve):
610 raise InvalidKeyError(
611 f"The key's curve '{key.curve.name}' does not match the expected "
612 f"curve '{self.expected_curve.name}' for this algorithm"
613 )
615 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
616 if isinstance(key, self._crypto_key_types):
617 ec_key = cast(AllowedECKeys, key)
618 self._validate_curve(ec_key)
619 return ec_key
621 if not isinstance(key, (bytes, str)):
622 raise TypeError("Expecting a PEM-formatted key.")
624 key_bytes = force_bytes(key)
626 # Attempt to load key. We don't know if it's
627 # a Signing Key or a Verifying Key, so we try
628 # the Verifying Key first.
629 try:
630 if key_bytes.startswith(b"ecdsa-sha2-"):
631 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
632 else:
633 public_key = load_pem_public_key(key_bytes)
635 # Explicit check the key to prevent confusing errors from cryptography
636 self.check_crypto_key_type(public_key)
637 ec_public_key = cast(EllipticCurvePublicKey, public_key)
638 self._validate_curve(ec_public_key)
639 return ec_public_key
640 except ValueError:
641 private_key = load_pem_private_key(key_bytes, password=None)
642 self.check_crypto_key_type(private_key)
643 ec_private_key = cast(EllipticCurvePrivateKey, private_key)
644 self._validate_curve(ec_private_key)
645 return ec_private_key
647 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
648 der_sig = key.sign(msg, ECDSA(self.hash_alg()))
650 return der_to_raw_signature(der_sig, key.curve)
652 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
653 try:
654 der_sig = raw_to_der_signature(sig, key.curve)
655 except ValueError:
656 return False
658 try:
659 public_key = (
660 key.public_key()
661 if isinstance(key, EllipticCurvePrivateKey)
662 else key
663 )
664 public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
665 return True
666 except InvalidSignature:
667 return False
669 @overload
670 @staticmethod
671 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ...
673 @overload
674 @staticmethod
675 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ...
677 @staticmethod
678 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
679 if isinstance(key_obj, EllipticCurvePrivateKey):
680 public_numbers = key_obj.public_key().public_numbers()
681 elif isinstance(key_obj, EllipticCurvePublicKey):
682 public_numbers = key_obj.public_numbers()
683 else:
684 raise InvalidKeyError("Not a public or private key")
686 if isinstance(key_obj.curve, SECP256R1):
687 crv = "P-256"
688 elif isinstance(key_obj.curve, SECP384R1):
689 crv = "P-384"
690 elif isinstance(key_obj.curve, SECP521R1):
691 crv = "P-521"
692 elif isinstance(key_obj.curve, SECP256K1):
693 crv = "secp256k1"
694 else:
695 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
697 obj: dict[str, Any] = {
698 "kty": "EC",
699 "crv": crv,
700 "x": to_base64url_uint(
701 public_numbers.x,
702 bit_length=key_obj.curve.key_size,
703 ).decode(),
704 "y": to_base64url_uint(
705 public_numbers.y,
706 bit_length=key_obj.curve.key_size,
707 ).decode(),
708 }
710 if isinstance(key_obj, EllipticCurvePrivateKey):
711 obj["d"] = to_base64url_uint(
712 key_obj.private_numbers().private_value,
713 bit_length=key_obj.curve.key_size,
714 ).decode()
716 if as_dict:
717 return obj
718 else:
719 return json.dumps(obj)
721 @staticmethod
722 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
723 try:
724 if isinstance(jwk, str):
725 obj = json.loads(jwk)
726 elif isinstance(jwk, dict):
727 obj = jwk
728 else:
729 raise ValueError
730 except ValueError:
731 raise InvalidKeyError("Key is not valid JSON") from None
733 if obj.get("kty") != "EC":
734 raise InvalidKeyError("Not an Elliptic curve key") from None
736 if "x" not in obj or "y" not in obj:
737 raise InvalidKeyError("Not an Elliptic curve key") from None
739 x = base64url_decode(obj.get("x"))
740 y = base64url_decode(obj.get("y"))
742 curve = obj.get("crv")
743 curve_obj: EllipticCurve
745 if curve == "P-256":
746 if len(x) == len(y) == 32:
747 curve_obj = SECP256R1()
748 else:
749 raise InvalidKeyError(
750 "Coords should be 32 bytes for curve P-256"
751 ) from None
752 elif curve == "P-384":
753 if len(x) == len(y) == 48:
754 curve_obj = SECP384R1()
755 else:
756 raise InvalidKeyError(
757 "Coords should be 48 bytes for curve P-384"
758 ) from None
759 elif curve == "P-521":
760 if len(x) == len(y) == 66:
761 curve_obj = SECP521R1()
762 else:
763 raise InvalidKeyError(
764 "Coords should be 66 bytes for curve P-521"
765 ) from None
766 elif curve == "secp256k1":
767 if len(x) == len(y) == 32:
768 curve_obj = SECP256K1()
769 else:
770 raise InvalidKeyError(
771 "Coords should be 32 bytes for curve secp256k1"
772 )
773 else:
774 raise InvalidKeyError(f"Invalid curve: {curve}")
776 public_numbers = EllipticCurvePublicNumbers(
777 x=int.from_bytes(x, byteorder="big"),
778 y=int.from_bytes(y, byteorder="big"),
779 curve=curve_obj,
780 )
782 if "d" not in obj:
783 return public_numbers.public_key()
785 d = base64url_decode(obj.get("d"))
786 if len(d) != len(x):
787 raise InvalidKeyError(
788 "D should be {} bytes for curve {}", len(x), curve
789 )
791 return EllipticCurvePrivateNumbers(
792 int.from_bytes(d, byteorder="big"), public_numbers
793 ).private_key()
795 class RSAPSSAlgorithm(RSAAlgorithm):
796 """
797 Performs a signature using RSASSA-PSS with MGF1
798 """
800 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
801 signature: bytes = key.sign(
802 msg,
803 padding.PSS(
804 mgf=padding.MGF1(self.hash_alg()),
805 salt_length=self.hash_alg().digest_size,
806 ),
807 self.hash_alg(),
808 )
809 return signature
811 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
812 try:
813 key.verify(
814 sig,
815 msg,
816 padding.PSS(
817 mgf=padding.MGF1(self.hash_alg()),
818 salt_length=self.hash_alg().digest_size,
819 ),
820 self.hash_alg(),
821 )
822 return True
823 except InvalidSignature:
824 return False
826 class OKPAlgorithm(Algorithm):
827 """
828 Performs signing and verification operations using EdDSA
830 This class requires ``cryptography>=2.6`` to be installed.
831 """
833 _crypto_key_types = cast(
834 tuple[type[AllowedKeys], ...],
835 get_args(
836 Union[
837 Ed25519PrivateKey,
838 Ed25519PublicKey,
839 Ed448PrivateKey,
840 Ed448PublicKey,
841 ]
842 ),
843 )
845 def __init__(self, **kwargs: Any) -> None:
846 pass
848 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
849 if not isinstance(key, (str, bytes)):
850 self.check_crypto_key_type(key)
851 return key
853 key_str = key.decode("utf-8") if isinstance(key, bytes) else key
854 key_bytes = key.encode("utf-8") if isinstance(key, str) else key
856 loaded_key: PublicKeyTypes | PrivateKeyTypes
857 if "-----BEGIN PUBLIC" in key_str:
858 loaded_key = load_pem_public_key(key_bytes)
859 elif "-----BEGIN PRIVATE" in key_str:
860 loaded_key = load_pem_private_key(key_bytes, password=None)
861 elif key_str[0:4] == "ssh-":
862 loaded_key = load_ssh_public_key(key_bytes)
863 else:
864 raise InvalidKeyError("Not a public or private key")
866 # Explicit check the key to prevent confusing errors from cryptography
867 self.check_crypto_key_type(loaded_key)
868 return cast("AllowedOKPKeys", loaded_key)
870 def sign(
871 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
872 ) -> bytes:
873 """
874 Sign a message ``msg`` using the EdDSA private key ``key``
875 :param str|bytes msg: Message to sign
876 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
877 or :class:`.Ed448PrivateKey` isinstance
878 :return bytes signature: The signature, as bytes
879 """
880 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
881 signature: bytes = key.sign(msg_bytes)
882 return signature
884 def verify(
885 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
886 ) -> bool:
887 """
888 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
890 :param str|bytes sig: EdDSA signature to check ``msg`` against
891 :param str|bytes msg: Message to sign
892 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
893 A private or public EdDSA key instance
894 :return bool verified: True if signature is valid, False if not.
895 """
896 try:
897 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
898 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
900 public_key = (
901 key.public_key()
902 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
903 else key
904 )
905 public_key.verify(sig_bytes, msg_bytes)
906 return True # If no exception was raised, the signature is valid.
907 except InvalidSignature:
908 return False
910 @overload
911 @staticmethod
912 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ...
914 @overload
915 @staticmethod
916 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ...
918 @staticmethod
919 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
920 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
921 x = key.public_bytes(
922 encoding=Encoding.Raw,
923 format=PublicFormat.Raw,
924 )
925 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
927 obj = {
928 "x": base64url_encode(force_bytes(x)).decode(),
929 "kty": "OKP",
930 "crv": crv,
931 }
933 if as_dict:
934 return obj
935 else:
936 return json.dumps(obj)
938 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
939 d = key.private_bytes(
940 encoding=Encoding.Raw,
941 format=PrivateFormat.Raw,
942 encryption_algorithm=NoEncryption(),
943 )
945 x = key.public_key().public_bytes(
946 encoding=Encoding.Raw,
947 format=PublicFormat.Raw,
948 )
950 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
951 obj = {
952 "x": base64url_encode(force_bytes(x)).decode(),
953 "d": base64url_encode(force_bytes(d)).decode(),
954 "kty": "OKP",
955 "crv": crv,
956 }
958 if as_dict:
959 return obj
960 else:
961 return json.dumps(obj)
963 raise InvalidKeyError("Not a public or private key")
965 @staticmethod
966 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
967 try:
968 if isinstance(jwk, str):
969 obj = json.loads(jwk)
970 elif isinstance(jwk, dict):
971 obj = jwk
972 else:
973 raise ValueError
974 except ValueError:
975 raise InvalidKeyError("Key is not valid JSON") from None
977 if obj.get("kty") != "OKP":
978 raise InvalidKeyError("Not an Octet Key Pair")
980 curve = obj.get("crv")
981 if curve != "Ed25519" and curve != "Ed448":
982 raise InvalidKeyError(f"Invalid curve: {curve}")
984 if "x" not in obj:
985 raise InvalidKeyError('OKP should have "x" parameter')
986 x = base64url_decode(obj.get("x"))
988 try:
989 if "d" not in obj:
990 if curve == "Ed25519":
991 return Ed25519PublicKey.from_public_bytes(x)
992 return Ed448PublicKey.from_public_bytes(x)
993 d = base64url_decode(obj.get("d"))
994 if curve == "Ed25519":
995 return Ed25519PrivateKey.from_private_bytes(d)
996 return Ed448PrivateKey.from_private_bytes(d)
997 except ValueError as err:
998 raise InvalidKeyError("Invalid key parameter") from err