Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 20%
Keyboard shortcuts on this page:
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6import os
7import sys
8from abc import ABC, abstractmethod
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 ClassVar,
13 Literal,
14 NoReturn,
15 Union,
16 cast,
17 get_args,
18 overload,
19)
21from .exceptions import InvalidKeyError
22from .types import HashlibHash, JWKDict
23from .utils import (
24 base64url_decode,
25 base64url_encode,
26 der_to_raw_signature,
27 force_bytes,
28 from_base64url_uint,
29 is_pem_format,
30 is_ssh_key,
31 raw_to_der_signature,
32 to_base64url_uint,
33)
try:
    # Everything imported in this ``try`` comes from the optional
    # ``cryptography`` package (plus a ``typing`` shim). If it is missing, the
    # ``except`` branch below collapses the key aliases to ``Never`` and sets
    # ``has_crypto`` to False.
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    if sys.version_info < (3, 10):
        # ``typing.TypeAlias`` only exists from Python 3.10 onwards.
        from typing_extensions import TypeAlias
    else:
        from typing import TypeAlias

    # Per-family key unions used in the algorithm method signatures.
    AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
    AllowedECKeys: TypeAlias = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]
    AllowedOKPKeys: TypeAlias = Union[
        Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
    ]
    # Any key object accepted by one of the cryptography-backed algorithms.
    AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
    #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
    AllowedPrivateKeys: TypeAlias = Union[
        RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
    ]
    #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
    AllowedPublicKeys: TypeAlias = Union[
        RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
    ]

    # Only imported for static type checking or when SPHINX_BUILD is set;
    # in this module these names are otherwise used solely in annotations.
    if TYPE_CHECKING or os.getenv("SPHINX_BUILD", ""):
        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

    has_crypto = True
except ModuleNotFoundError:
    if sys.version_info < (3, 11):
        from typing_extensions import Never
    else:
        from typing import Never

    # Without ``cryptography`` no key object can satisfy these aliases.
    AllowedRSAKeys = Never  # type: ignore[misc]
    AllowedECKeys = Never  # type: ignore[misc]
    AllowedOKPKeys = Never  # type: ignore[misc]
    AllowedKeys = Never  # type: ignore[misc]
    AllowedPrivateKeys = Never  # type: ignore[misc]
    AllowedPublicKeys = Never  # type: ignore[misc]
    has_crypto = False
# JWS "alg" names whose implementations are only available when the optional
# ``cryptography`` package is installed ("none" and HS* need only the stdlib).
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256",
    "RS384",
    "RS512",
    # RSASSA-PSS
    "PS256",
    "PS384",
    "PS512",
    # ECDSA; "ES512" is a backward-compatible alias that also uses P-521
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    # Edwards-curve signatures
    "EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Return a newly built mapping of JWS ``alg`` names to algorithm instances.

    ``none`` and the HMAC family are always included. The RSA, ECDSA,
    RSA-PSS and EdDSA entries are added only when ``has_crypto`` is true.
    """
    algorithms: dict[str, Algorithm] = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    if not has_crypto:
        return algorithms

    algorithms |= {
        "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
        "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
        "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
        "ES256": ECAlgorithm(ECAlgorithm.SHA256, SECP256R1),
        "ES256K": ECAlgorithm(ECAlgorithm.SHA256, SECP256K1),
        "ES384": ECAlgorithm(ECAlgorithm.SHA384, SECP384R1),
        "ES521": ECAlgorithm(ECAlgorithm.SHA512, SECP521R1),
        # Backward compat for #219 fix: ES512 also maps to SHA-512 on P-521.
        "ES512": ECAlgorithm(ECAlgorithm.SHA512, SECP521R1),
        "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
        "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
        "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
        "EdDSA": OKPAlgorithm(),
    }
    return algorithms
class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Subclasses implement key preparation, signing, verification and JWK
    (de)serialization. Cryptography-backed subclasses also set
    ``_crypto_key_types`` so that :meth:`check_crypto_key_type` can reject
    keys belonging to the wrong key family.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    # ``None`` means this algorithm does not work with ``cryptography`` key objects.
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        ``self.hash_alg`` may be a ``cryptography`` ``HashAlgorithm`` subclass
        (hashed through ``hashes.Hash``) or a hashlib-style constructor that
        is called with the data and must return an object with ``digest()``.

        :param bytestr: The data to hash.
        :return: The raw digest bytes.
        :raises NotImplementedError: if the algorithm has no ``hash_alg``.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Join the names eagerly: interpolating a generator expression
            # into the f-string would print "<generator object ...>".
            valid_classes = ", ".join(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK.

        :param key_obj: The key to serialize.
        :param as_dict: Return a dict instead of a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object.

        :param jwk: The JWK as a JSON string or an already-parsed dict.
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        The base implementation performs no check and always returns None.
        """
        return None
class NoneAlgorithm(Algorithm):
    """
    Algorithm for unsecured tokens (``alg = "none"``).

    Signing yields an empty signature and verification never succeeds; the
    only acceptable key is ``None`` (an empty string is treated as ``None``).
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is accepted as a synonym for "no key".
        normalized = None if key == "" else key
        if normalized is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsecured tokens carry an empty signature.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Never treat an unsecured token as verified.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key material to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key material to deserialize.
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Signs and verifies tokens with HMAC, using one of the hashlib digest
    constructors exposed as the ``SHA256`` / ``SHA384`` / ``SHA512``
    class attributes.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    @staticmethod
    def _is_jwk_document(data: bytes) -> bool:
        """Return True when ``data`` parses as a JSON object with a ``kty`` member."""
        try:
            parsed = json.loads(data)
        except ValueError:
            return False
        return isinstance(parsed, dict) and "kty" in parsed

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Validate ``key`` and return it as bytes for use as the HMAC secret.

        :raises InvalidKeyError: if the key is empty, looks like a PEM/SSH
            asymmetric key or certificate, or looks like a serialized JWK.
        """
        secret = force_bytes(key)

        if not secret:
            raise InvalidKeyError("HMAC key must not be empty.")

        if is_pem_format(secret) or is_ssh_key(secret):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        # Algorithm-confusion defence: whoever controls the token header can
        # select alg=HS*, and HMAC is the only algorithm that accepts arbitrary
        # bytes as a secret. Raw JSON JWK text is not the secret material, so
        # even an ``oct`` JWK must go through PyJWK / from_jwk instead.
        if secret.lstrip().startswith(b"{") and self._is_jwk_document(secret):
            raise InvalidKeyError(
                "The specified key looks like a JWK and should not be "
                "used directly as an HMAC secret. Load it via "
                "PyJWK / HMACAlgorithm.from_jwk first."
            )

        return secret

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """Serialize ``key_obj`` as an ``oct`` JWK (dict or JSON string)."""
        encoded_secret = base64url_encode(force_bytes(key_obj)).decode()
        jwk = {"k": encoded_secret, "kty": "oct"}
        return jwk if as_dict else json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode the secret from an ``oct`` JWK given as a JSON string or dict.

        :raises InvalidKeyError: if ``jwk`` is not valid JSON (or is neither a
            str nor a dict), or its ``kty`` is not ``"oct"``.
        """
        if isinstance(jwk, dict):
            obj: JWKDict = jwk
        elif isinstance(jwk, str):
            try:
                obj = json.loads(jwk)
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None
        else:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(obj["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """Warn when the secret is shorter than the hash's digest size."""
        hasher = self.hash_alg()
        min_length = hasher.digest_size
        if len(key) >= min_length:
            return None
        return (
            f"The HMAC key is {len(key)} bytes long, which is below "
            f"the minimum recommended length of {min_length} bytes for "
            f"{hasher.name.upper()}. "
            f"See RFC 7518 Section 3.2."
        )

    def sign(self, msg: bytes, key: bytes) -> bytes:
        mac = hmac.new(key, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison against a freshly computed MAC.
        expected = self.sign(msg, key)
        return hmac.compare_digest(sig, expected)
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Signs and verifies tokens with RSASSA-PKCS1-v1_5 using one of the
        ``cryptography`` hash classes exposed as ``SHA256``/``SHA384``/``SHA512``.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = cast(
            "tuple[type[AllowedKeys], ...]",
            (RSAPrivateKey, RSAPublicKey),
        )
        # Keys below this many bits trigger a warning in check_key_length.
        _MIN_KEY_SIZE: ClassVar[int] = 2048

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def check_key_length(self, key: AllowedRSAKeys) -> str | None:
            """Warn when the RSA modulus is shorter than ``_MIN_KEY_SIZE`` bits."""
            bits = key.key_size
            if bits >= self._MIN_KEY_SIZE:
                return None
            return (
                f"The RSA key is {bits} bits long, which is below "
                f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. "
                f"See NIST SP 800-131A."
            )

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return an RSA key object for ``key``.

            RSA key objects are returned unchanged. ``str``/``bytes`` input is
            loaded as an OpenSSH public key (``ssh-rsa`` prefix) or a PEM
            private key, falling back to a PEM public key on ``ValueError``.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the fallback public-key parse fails or
                the loaded key is not an RSA key.
            """
            if isinstance(key, self._crypto_key_types):
                # Cast is required for type narrowing on Python 3.9's mypy
                # but redundant on newer mypy versions; suppress both
                # diagnostics so the line works across all supported envs.
                return cast(AllowedRSAKeys, key)  # type: ignore[redundant-cast,unused-ignore]

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            pem = force_bytes(key)

            try:
                loaded: PublicKeyTypes | PrivateKeyTypes
                if pem.startswith(b"ssh-rsa"):
                    loaded = load_ssh_public_key(pem)
                else:
                    loaded = load_pem_private_key(pem, password=None)
                self.check_crypto_key_type(loaded)
                return cast("AllowedRSAKeys", loaded)
            except ValueError:
                # Neither SSH nor private PEM: try a public PEM as a last resort.
                try:
                    fallback = load_pem_public_key(pem)
                    self.check_crypto_key_type(fallback)
                    return cast(RSAPublicKey, fallback)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK (dict or JSON string).

            Objects exposing ``private_numbers`` are treated as private keys
            (all CRT parameters emitted); objects exposing ``verify`` as public.
            """

            def b64(value: int) -> str:
                return to_base64url_uint(value).decode()

            obj: dict[str, Any]
            if hasattr(key_obj, "private_numbers"):
                private = key_obj.private_numbers()
                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": b64(private.public_numbers.n),
                    "e": b64(private.public_numbers.e),
                    "d": b64(private.d),
                    "p": b64(private.p),
                    "q": b64(private.q),
                    "dp": b64(private.dmp1),
                    "dq": b64(private.dmq1),
                    "qi": b64(private.iqmp),
                }
            elif hasattr(key_obj, "verify"):
                public = key_obj.public_numbers()
                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": b64(public.n),
                    "e": b64(public.e),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            return obj if as_dict else json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key object from a JWK given as a JSON string or dict.

            A JWK with ``n`` and ``e`` but no ``d`` yields a public key. With
            ``d`` it yields a private key; missing CRT parameters are
            recovered from ``n``, ``e`` and ``d``.

            :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``,
                multi-prime keys (``oth``), partial CRT parameters, or
                missing ``n``/``e``.
            """
            if isinstance(jwk, dict):
                obj = jwk
            elif isinstance(jwk, str):
                try:
                    obj = json.loads(jwk)
                except ValueError:
                    raise InvalidKeyError("Key is not valid JSON") from None
            else:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if "n" not in obj or "e" not in obj:
                raise InvalidKeyError("Not a public or private key")

            if "d" not in obj:
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()

            # From here on this is a private key.
            if "oth" in obj:
                raise InvalidKeyError(
                    "Unsupported RSA private key: > 2 primes not supported"
                )

            crt_fields = ("p", "q", "dp", "dq", "qi")
            present = [field in obj for field in crt_fields]
            if any(present) and not all(present):
                raise InvalidKeyError(
                    "RSA key must include all parameters if any are present besides d"
                ) from None

            public_numbers = RSAPublicNumbers(
                from_base64url_uint(obj["e"]),
                from_base64url_uint(obj["n"]),
            )

            if all(present):
                numbers = RSAPrivateNumbers(
                    d=from_base64url_uint(obj["d"]),
                    p=from_base64url_uint(obj["p"]),
                    q=from_base64url_uint(obj["q"]),
                    dmp1=from_base64url_uint(obj["dp"]),
                    dmq1=from_base64url_uint(obj["dq"]),
                    iqmp=from_base64url_uint(obj["qi"]),
                    public_numbers=public_numbers,
                )
            else:
                # Only d was supplied: recover the primes and derive the CRT values.
                d = from_base64url_uint(obj["d"])
                p, q = rsa_recover_prime_factors(
                    public_numbers.n, d, public_numbers.e
                )
                numbers = RSAPrivateNumbers(
                    d=d,
                    p=p,
                    q=q,
                    dmp1=rsa_crt_dmp1(d, p),
                    dmq1=rsa_crt_dmq1(d, q),
                    iqmp=rsa_crt_iqmp(p, q),
                    public_numbers=public_numbers,
                )

            return numbers.private_key()

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
            except InvalidSignature:
                return False
            return True
607 class ECAlgorithm(Algorithm):
608 """
609 Performs signing and verification operations using
610 ECDSA and the specified hash function
611 """
613 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
614 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
615 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
617 _crypto_key_types = cast(
618 tuple[type[AllowedKeys], ...],
619 get_args(Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]),
620 )
622 def __init__(
623 self,
624 hash_alg: type[hashes.HashAlgorithm],
625 expected_curve: type[EllipticCurve] | None = None,
626 ) -> None:
627 self.hash_alg = hash_alg
628 self.expected_curve = expected_curve
630 def _validate_curve(self, key: AllowedECKeys) -> None:
631 """Validate that the key's curve matches the expected curve."""
632 if self.expected_curve is None:
633 return
635 if not isinstance(key.curve, self.expected_curve):
636 raise InvalidKeyError(
637 f"The key's curve '{key.curve.name}' does not match the expected "
638 f"curve '{self.expected_curve.name}' for this algorithm"
639 )
641 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
642 if isinstance(key, self._crypto_key_types):
643 # See note in RSAAlgorithm.prepare_key.
644 ec_key = cast(AllowedECKeys, key) # type: ignore[redundant-cast,unused-ignore]
645 self._validate_curve(ec_key)
646 return ec_key
648 if not isinstance(key, (bytes, str)):
649 raise TypeError("Expecting a PEM-formatted key.")
651 key_bytes = force_bytes(key)
653 # Attempt to load key. We don't know if it's
654 # a Signing Key or a Verifying Key, so we try
655 # the Verifying Key first.
656 try:
657 if key_bytes.startswith(b"ecdsa-sha2-"):
658 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
659 else:
660 public_key = load_pem_public_key(key_bytes)
662 # Explicit check the key to prevent confusing errors from cryptography
663 self.check_crypto_key_type(public_key)
664 ec_public_key = cast(EllipticCurvePublicKey, public_key)
665 self._validate_curve(ec_public_key)
666 return ec_public_key
667 except ValueError:
668 private_key = load_pem_private_key(key_bytes, password=None)
669 self.check_crypto_key_type(private_key)
670 ec_private_key = cast(EllipticCurvePrivateKey, private_key)
671 self._validate_curve(ec_private_key)
672 return ec_private_key
674 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
675 der_sig = key.sign(msg, ECDSA(self.hash_alg()))
677 return der_to_raw_signature(der_sig, key.curve)
679 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
680 try:
681 der_sig = raw_to_der_signature(sig, key.curve)
682 except ValueError:
683 return False
685 try:
686 public_key = (
687 key.public_key()
688 if isinstance(key, EllipticCurvePrivateKey)
689 else key
690 )
691 public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
692 return True
693 except InvalidSignature:
694 return False
696 @overload
697 @staticmethod
698 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ...
700 @overload
701 @staticmethod
702 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ...
704 @staticmethod
705 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
706 if isinstance(key_obj, EllipticCurvePrivateKey):
707 public_numbers = key_obj.public_key().public_numbers()
708 elif isinstance(key_obj, EllipticCurvePublicKey):
709 public_numbers = key_obj.public_numbers()
710 else:
711 raise InvalidKeyError("Not a public or private key")
713 if isinstance(key_obj.curve, SECP256R1):
714 crv = "P-256"
715 elif isinstance(key_obj.curve, SECP384R1):
716 crv = "P-384"
717 elif isinstance(key_obj.curve, SECP521R1):
718 crv = "P-521"
719 elif isinstance(key_obj.curve, SECP256K1):
720 crv = "secp256k1"
721 else:
722 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
724 obj: dict[str, Any] = {
725 "kty": "EC",
726 "crv": crv,
727 "x": to_base64url_uint(
728 public_numbers.x,
729 bit_length=key_obj.curve.key_size,
730 ).decode(),
731 "y": to_base64url_uint(
732 public_numbers.y,
733 bit_length=key_obj.curve.key_size,
734 ).decode(),
735 }
737 if isinstance(key_obj, EllipticCurvePrivateKey):
738 obj["d"] = to_base64url_uint(
739 key_obj.private_numbers().private_value,
740 bit_length=key_obj.curve.key_size,
741 ).decode()
743 if as_dict:
744 return obj
745 else:
746 return json.dumps(obj)
748 @staticmethod
749 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
750 try:
751 if isinstance(jwk, str):
752 obj = json.loads(jwk)
753 elif isinstance(jwk, dict):
754 obj = jwk
755 else:
756 raise ValueError
757 except ValueError:
758 raise InvalidKeyError("Key is not valid JSON") from None
760 if obj.get("kty") != "EC":
761 raise InvalidKeyError("Not an Elliptic curve key") from None
763 if "x" not in obj or "y" not in obj:
764 raise InvalidKeyError("Not an Elliptic curve key") from None
766 x = base64url_decode(obj.get("x"))
767 y = base64url_decode(obj.get("y"))
769 curve = obj.get("crv")
770 curve_obj: EllipticCurve
772 if curve == "P-256":
773 if len(x) == len(y) == 32:
774 curve_obj = SECP256R1()
775 else:
776 raise InvalidKeyError(
777 "Coords should be 32 bytes for curve P-256"
778 ) from None
779 elif curve == "P-384":
780 if len(x) == len(y) == 48:
781 curve_obj = SECP384R1()
782 else:
783 raise InvalidKeyError(
784 "Coords should be 48 bytes for curve P-384"
785 ) from None
786 elif curve == "P-521":
787 if len(x) == len(y) == 66:
788 curve_obj = SECP521R1()
789 else:
790 raise InvalidKeyError(
791 "Coords should be 66 bytes for curve P-521"
792 ) from None
793 elif curve == "secp256k1":
794 if len(x) == len(y) == 32:
795 curve_obj = SECP256K1()
796 else:
797 raise InvalidKeyError(
798 "Coords should be 32 bytes for curve secp256k1"
799 )
800 else:
801 raise InvalidKeyError(f"Invalid curve: {curve}")
803 public_numbers = EllipticCurvePublicNumbers(
804 x=int.from_bytes(x, byteorder="big"),
805 y=int.from_bytes(y, byteorder="big"),
806 curve=curve_obj,
807 )
809 if "d" not in obj:
810 return public_numbers.public_key()
812 d = base64url_decode(obj.get("d"))
813 if len(d) != len(x):
814 raise InvalidKeyError(
815 "D should be {} bytes for curve {}", len(x), curve
816 )
818 return EllipticCurvePrivateNumbers(
819 int.from_bytes(d, byteorder="big"), public_numbers
820 ).private_key()
822 class RSAPSSAlgorithm(RSAAlgorithm):
823 """
824 Performs a signature using RSASSA-PSS with MGF1
825 """
827 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
828 signature: bytes = key.sign(
829 msg,
830 padding.PSS(
831 mgf=padding.MGF1(self.hash_alg()),
832 salt_length=self.hash_alg().digest_size,
833 ),
834 self.hash_alg(),
835 )
836 return signature
838 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
839 try:
840 key.verify(
841 sig,
842 msg,
843 padding.PSS(
844 mgf=padding.MGF1(self.hash_alg()),
845 salt_length=self.hash_alg().digest_size,
846 ),
847 self.hash_alg(),
848 )
849 return True
850 except InvalidSignature:
851 return False
853 class OKPAlgorithm(Algorithm):
854 """
855 Performs signing and verification operations using EdDSA
857 This class requires ``cryptography>=2.6`` to be installed.
858 """
860 _crypto_key_types = cast(
861 tuple[type[AllowedKeys], ...],
862 get_args(
863 Union[
864 Ed25519PrivateKey,
865 Ed25519PublicKey,
866 Ed448PrivateKey,
867 Ed448PublicKey,
868 ]
869 ),
870 )
872 def __init__(self, **kwargs: Any) -> None:
873 pass
875 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
876 if not isinstance(key, (str, bytes)):
877 self.check_crypto_key_type(key)
878 return key
880 key_str = key.decode("utf-8") if isinstance(key, bytes) else key
881 key_bytes = key.encode("utf-8") if isinstance(key, str) else key
883 loaded_key: PublicKeyTypes | PrivateKeyTypes
884 if "-----BEGIN PUBLIC" in key_str:
885 loaded_key = load_pem_public_key(key_bytes)
886 elif "-----BEGIN PRIVATE" in key_str:
887 loaded_key = load_pem_private_key(key_bytes, password=None)
888 elif key_str[0:4] == "ssh-":
889 loaded_key = load_ssh_public_key(key_bytes)
890 else:
891 raise InvalidKeyError("Not a public or private key")
893 # Explicit check the key to prevent confusing errors from cryptography
894 self.check_crypto_key_type(loaded_key)
895 return cast("AllowedOKPKeys", loaded_key)
897 def sign(
898 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
899 ) -> bytes:
900 """
901 Sign a message ``msg`` using the EdDSA private key ``key``
902 :param str|bytes msg: Message to sign
903 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
904 or :class:`.Ed448PrivateKey` isinstance
905 :return bytes signature: The signature, as bytes
906 """
907 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
908 signature: bytes = key.sign(msg_bytes)
909 return signature
911 def verify(
912 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
913 ) -> bool:
914 """
915 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
917 :param str|bytes sig: EdDSA signature to check ``msg`` against
918 :param str|bytes msg: Message to sign
919 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
920 A private or public EdDSA key instance
921 :return bool verified: True if signature is valid, False if not.
922 """
923 try:
924 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
925 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
927 public_key = (
928 key.public_key()
929 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
930 else key
931 )
932 public_key.verify(sig_bytes, msg_bytes)
933 return True # If no exception was raised, the signature is valid.
934 except InvalidSignature:
935 return False
937 @overload
938 @staticmethod
939 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ...
941 @overload
942 @staticmethod
943 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ...
945 @staticmethod
946 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
947 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
948 x = key.public_bytes(
949 encoding=Encoding.Raw,
950 format=PublicFormat.Raw,
951 )
952 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
954 obj = {
955 "x": base64url_encode(force_bytes(x)).decode(),
956 "kty": "OKP",
957 "crv": crv,
958 }
960 if as_dict:
961 return obj
962 else:
963 return json.dumps(obj)
965 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
966 d = key.private_bytes(
967 encoding=Encoding.Raw,
968 format=PrivateFormat.Raw,
969 encryption_algorithm=NoEncryption(),
970 )
972 x = key.public_key().public_bytes(
973 encoding=Encoding.Raw,
974 format=PublicFormat.Raw,
975 )
977 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
978 obj = {
979 "x": base64url_encode(force_bytes(x)).decode(),
980 "d": base64url_encode(force_bytes(d)).decode(),
981 "kty": "OKP",
982 "crv": crv,
983 }
985 if as_dict:
986 return obj
987 else:
988 return json.dumps(obj)
990 raise InvalidKeyError("Not a public or private key")
992 @staticmethod
993 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
994 try:
995 if isinstance(jwk, str):
996 obj = json.loads(jwk)
997 elif isinstance(jwk, dict):
998 obj = jwk
999 else:
1000 raise ValueError
1001 except ValueError:
1002 raise InvalidKeyError("Key is not valid JSON") from None
1004 if obj.get("kty") != "OKP":
1005 raise InvalidKeyError("Not an Octet Key Pair")
1007 curve = obj.get("crv")
1008 if curve != "Ed25519" and curve != "Ed448":
1009 raise InvalidKeyError(f"Invalid curve: {curve}")
1011 if "x" not in obj:
1012 raise InvalidKeyError('OKP should have "x" parameter')
1013 x = base64url_decode(obj.get("x"))
1015 try:
1016 if "d" not in obj:
1017 if curve == "Ed25519":
1018 return Ed25519PublicKey.from_public_bytes(x)
1019 return Ed448PublicKey.from_public_bytes(x)
1020 d = base64url_decode(obj.get("d"))
1021 if curve == "Ed25519":
1022 return Ed25519PrivateKey.from_private_bytes(d)
1023 return Ed448PrivateKey.from_private_bytes(d)
1024 except ValueError as err:
1025 raise InvalidKeyError("Invalid key parameter") from err