Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 17%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6import os
7from abc import ABC, abstractmethod
8from typing import (
9 TYPE_CHECKING,
10 Any,
11 ClassVar,
12 Literal,
13 NoReturn,
14 Union,
15 cast,
16 overload,
17)
19from .exceptions import InvalidKeyError
20from .types import HashlibHash, JWKDict
21from .utils import (
22 base64url_decode,
23 base64url_encode,
24 der_to_raw_signature,
25 force_bytes,
26 from_base64url_uint,
27 is_pem_format,
28 is_ssh_key,
29 raw_to_der_signature,
30 to_base64url_uint,
31)
try:
    # Everything taken from ``cryptography`` is imported inside this ``try``.
    # If the package is not installed, the first import raises
    # ModuleNotFoundError and the handler at the bottom sets has_crypto = False.
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    # pyjwt-964: runtime tuples of key classes. They are used for type
    # checking below and for isinstance() validation of the key passed in.
    # On Python >= 3.10 they could be replaced by the Union aliases below.
    ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
    ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
    ALLOWED_OKP_KEY_TYPES = (
        Ed25519PrivateKey,
        Ed25519PublicKey,
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    # All three families concatenated, in RSA, EC, OKP order.
    ALLOWED_KEY_TYPES = (
        *ALLOWED_RSA_KEY_TYPES,
        *ALLOWED_EC_KEY_TYPES,
        *ALLOWED_OKP_KEY_TYPES,
    )
    # The same classes grouped by role instead of by family.
    ALLOWED_PRIVATE_KEY_TYPES = (
        RSAPrivateKey,
        EllipticCurvePrivateKey,
        Ed25519PrivateKey,
        Ed448PrivateKey,
    )
    ALLOWED_PUBLIC_KEY_TYPES = (
        RSAPublicKey,
        EllipticCurvePublicKey,
        Ed25519PublicKey,
        Ed448PublicKey,
    )

    # The Union aliases are only built for static type checkers and for
    # Sphinx documentation builds (signalled by a non-empty SPHINX_BUILD).
    if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
        import sys

        if sys.version_info < (3, 10):
            # Python 3.9 and lower: TypeAlias must come from typing_extensions.
            from typing_extensions import TypeAlias
        else:
            from typing import TypeAlias

        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

        # Type aliases for convenience in algorithms method signatures
        AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
        AllowedECKeys: TypeAlias = Union[
            EllipticCurvePrivateKey, EllipticCurvePublicKey
        ]
        AllowedOKPKeys: TypeAlias = Union[
            Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
        ]
        AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
        #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
        AllowedPrivateKeys: TypeAlias = Union[
            RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
        ]
        #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
        AllowedPublicKeys: TypeAlias = Union[
            RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
        ]

    has_crypto = True
except ModuleNotFoundError:
    has_crypto = False
# Names of the algorithms that get_default_algorithms() only registers when
# ``cryptography`` could be imported (has_crypto is True).
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256",
    "RS384",
    "RS512",
    # ECDSA
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    # RSASSA-PSS
    "PS256",
    "PS384",
    "PS512",
    # EdDSA
    "EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build a new mapping from algorithm name to an Algorithm instance for every
    algorithm the library implements.

    ``none`` and the HMAC family are always present. The RSA, ECDSA, RSA-PSS
    and EdDSA entries are added only when ``cryptography`` is available.
    Fresh instances are created on every call.
    """
    algorithms: dict[str, Algorithm] = {}
    algorithms["none"] = NoneAlgorithm()
    algorithms["HS256"] = HMACAlgorithm(HMACAlgorithm.SHA256)
    algorithms["HS384"] = HMACAlgorithm(HMACAlgorithm.SHA384)
    algorithms["HS512"] = HMACAlgorithm(HMACAlgorithm.SHA512)

    if not has_crypto:
        return algorithms

    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512)
    # Backward compat for #219 fix
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512)
    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    algorithms["EdDSA"] = OKPAlgorithm()
    return algorithms
class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Concrete subclasses implement key preparation, signing, verification and
    JWK (de)serialization.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct
    # cryptography key family. cryptography-backed subclasses set this to the
    # tuple of accepted key classes; None means "not a cryptography algorithm".
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        ``self.hash_alg`` may be a ``cryptography`` ``HashAlgorithm`` subclass,
        which is used through ``hashes.Hash``. Otherwise it is treated as a
        hashlib-style callable, ``hash_alg(data).digest()``.

        If there is no hash algorithm, raises a NotImplementedError.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Materialize the names into a string. Interpolating a generator
            # expression would render as "<generator object ...>".
            valid_classes = ", ".join(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK.

        Returns a dict when ``as_dict`` is true, otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """
class NoneAlgorithm(Algorithm):
    """
    Stand-in for ``alg="none"``.

    It accepts no key, produces an empty signature and never reports a
    signature as valid.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is tolerated as a synonym for "no key"; anything
        # else is rejected.
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # There is nothing to sign with, so the signature is always empty.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Unsigned tokens are never considered verified.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # "none" has no key material to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # "none" has no key material to deserialize.
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # hash_alg is a hashlib constructor such as hashlib.sha256; it is passed
        # straight to hmac.new() as the digestmod.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert the shared secret to bytes.

        :raises InvalidKeyError: if the value looks like a PEM/x509 or SSH
            asymmetric key, which must not be used as an HMAC secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """Serialize the secret as an ``oct`` JWK (dict if ``as_dict``, else JSON)."""
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Extract the raw secret from an ``oct`` JWK (JSON text or dict).

        :raises InvalidKeyError: on invalid JSON, a ``kty`` other than
            ``oct``, or a missing ``k`` member.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        # Report a missing secret as an invalid key rather than leaking a bare
        # KeyError (mirrors the "x" check in OKPAlgorithm.from_jwk).
        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Recompute the MAC and compare in constant time.
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Sign and verify with RSASSA-PKCS1-v1_5 padding, using the hash
        function given at construction time.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_RSA_KEY_TYPES

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Turn ``key`` into a ``cryptography`` RSA key object.

            An RSA key object is returned unchanged. Text starting with
            ``ssh-rsa`` is loaded as an OpenSSH public key; any other text is
            loaded as a PEM private key. If either load raises ``ValueError``,
            the text is retried as a PEM public key.

            :raises TypeError: if ``key`` is not a key object, ``str`` or ``bytes``
            :raises InvalidKeyError: if the public-key fallback cannot parse the
                text, or the loaded key is not an RSA key
            """
            if isinstance(key, self._crypto_key_types):
                return key
            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            pem = force_bytes(key)
            try:
                if not pem.startswith(b"ssh-rsa"):
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        pem, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
                ssh_key: PublicKeyTypes = load_ssh_public_key(pem)
                self.check_crypto_key_type(ssh_key)
                return cast(RSAPublicKey, ssh_key)
            except ValueError:
                # Fallback: the text may be a PEM public key.
                try:
                    public_key = load_pem_public_key(pem)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK.

            Objects exposing ``private_numbers`` become a private JWK with the
            full CRT parameter set and ``key_ops: ["sign"]``. Objects exposing
            ``verify`` become a public JWK with ``key_ops: ["verify"]``.

            :raises InvalidKeyError: if the object has neither attribute
            """

            def b64(value: int) -> str:
                return to_base64url_uint(value).decode()

            if hasattr(key_obj, "private_numbers"):
                priv = key_obj.private_numbers()
                pub = priv.public_numbers
                obj: dict[str, Any] = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": b64(pub.n),
                    "e": b64(pub.e),
                    "d": b64(priv.d),
                    "p": b64(priv.p),
                    "q": b64(priv.q),
                    "dp": b64(priv.dmp1),
                    "dq": b64(priv.dmq1),
                    "qi": b64(priv.iqmp),
                }
            elif hasattr(key_obj, "verify"):
                pub_numbers = key_obj.public_numbers()
                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": b64(pub_numbers.n),
                    "e": b64(pub_numbers.e),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            return obj if as_dict else json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Rebuild an RSA key object from a JWK (JSON text or parsed dict).

            A JWK carrying ``d``, ``e`` and ``n`` yields a private key. Its CRT
            members (``p``, ``q``, ``dp``, ``dq``, ``qi``) must be all present
            or all absent; when absent, the primes are recovered from n, e and
            d. A JWK with only ``n`` and ``e`` yields a public key.

            :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``, an
                ``oth`` (multi-prime) member, a partial CRT set, or when neither
                a public nor a private key can be formed
            """
            try:
                if isinstance(jwk, dict):
                    obj = jwk
                elif isinstance(jwk, str):
                    obj = json.loads(jwk)
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            is_private = "d" in obj and "e" in obj and "n" in obj
            if not is_private:
                if "n" in obj and "e" in obj:
                    return RSAPublicNumbers(
                        from_base64url_uint(obj["e"]),
                        from_base64url_uint(obj["n"]),
                    ).public_key()
                raise InvalidKeyError("Not a public or private key")

            if "oth" in obj:
                raise InvalidKeyError(
                    "Unsupported RSA private key: > 2 primes not supported"
                )

            crt_names = ("p", "q", "dp", "dq", "qi")
            crt_present = [name in obj for name in crt_names]
            has_crt = any(crt_present)
            if has_crt and not all(crt_present):
                raise InvalidKeyError(
                    "RSA key must include all parameters if any are present besides d"
                ) from None

            public_numbers = RSAPublicNumbers(
                from_base64url_uint(obj["e"]),
                from_base64url_uint(obj["n"]),
            )
            d = from_base64url_uint(obj["d"])

            if has_crt:
                p = from_base64url_uint(obj["p"])
                q = from_base64url_uint(obj["q"])
                dmp1 = from_base64url_uint(obj["dp"])
                dmq1 = from_base64url_uint(obj["dq"])
                iqmp = from_base64url_uint(obj["qi"])
            else:
                # Only d was supplied: derive the primes and CRT values.
                p, q = rsa_recover_prime_factors(
                    public_numbers.n, d, public_numbers.e
                )
                dmp1 = rsa_crt_dmp1(d, p)
                dmq1 = rsa_crt_dmq1(d, q)
                iqmp = rsa_crt_iqmp(p, q)

            return RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=dmp1,
                dmq1=dmq1,
                iqmp=iqmp,
                public_numbers=public_numbers,
            ).private_key()

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            sig_bytes: bytes = key.sign(msg, padding.PKCS1v15(), self.hash_alg())
            return sig_bytes

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
            except InvalidSignature:
                return False
            return True
if has_crypto:

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function.

        sign() converts cryptography's DER signature with
        ``der_to_raw_signature``. verify() converts the incoming signature
        back with ``raw_to_der_signature``.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_EC_KEY_TYPES

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Turn ``key`` into a ``cryptography`` EC key object.

            EC key objects are returned unchanged. Text starting with
            ``ecdsa-sha2-`` is loaded as an OpenSSH public key; other text as a
            PEM public key. If that raises ``ValueError``, the text is loaded
            as a PEM private key.

            :raises TypeError: if ``key`` is not a key object, ``str`` or ``bytes``
            :raises InvalidKeyError: if the loaded key is not an EC key
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                else:
                    public_key = load_pem_public_key(key_bytes)

                # Explicit check the key to prevent confusing errors from cryptography
                self.check_crypto_key_type(public_key)
                return cast(EllipticCurvePublicKey, public_key)
            except ValueError:
                private_key = load_pem_private_key(key_bytes, password=None)
                self.check_crypto_key_type(private_key)
                return cast(EllipticCurvePrivateKey, private_key)

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Return True if ``sig`` is a valid signature of ``msg``, else False.

            A private key is accepted and its public half is used. A signature
            that cannot be converted to DER for the key's curve counts as invalid.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key as a JWK with ``kty``, ``crv``, ``x`` and ``y``.

            ``d`` is added for private keys. Coordinates are padded to the
            curve's key size.

            :raises InvalidKeyError: for non-EC objects or unsupported curves
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Rebuild an EC key object from a JWK (JSON text or parsed dict).

            The coordinate lengths must match the named curve. With ``d``
            present, a private key is returned; ``d`` must be as long as ``x``.

            :raises InvalidKeyError: on invalid JSON, a non-EC ``kty``, missing
                coordinates, an unsupported curve, or wrong-length values
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key") from None

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key") from None

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            curve_obj: EllipticCurve

            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve P-256"
                    ) from None
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = SECP384R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 48 bytes for curve P-384"
                    ) from None
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = SECP521R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 66 bytes for curve P-521"
                    ) from None
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                # Format the message. The original passed the format args as
                # extra exception arguments, so the text was never filled in.
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            return EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()
if has_crypto:

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1.

        Key handling and JWK conversion are inherited from RSAAlgorithm; only
        the padding differs. MGF1 and the message digest both use
        ``self.hash_alg``, and the salt length equals that hash's digest size.
        """

        def _pss_padding(self) -> padding.PSS:
            # Shared by sign() and verify() so both sides use identical parameters.
            return padding.PSS(
                mgf=padding.MGF1(self.hash_alg()),
                salt_length=self.hash_alg().digest_size,
            )

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            signature: bytes = key.sign(msg, self._pss_padding(), self.hash_alg())
            return signature

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, self._pss_padding(), self.hash_alg())
            except InvalidSignature:
                return False
            return True
if has_crypto:

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA over
        Ed25519 or Ed448 keys.

        This class requires ``cryptography>=2.6`` to be installed.
        """

        _crypto_key_types = ALLOWED_OKP_KEY_TYPES

        def __init__(self, **kwargs: Any) -> None:
            # Any keyword arguments are accepted and ignored.
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Turn ``key`` into an Ed25519/Ed448 key object.

            Key objects are type-checked and returned unchanged. Text is
            dispatched on its content: a ``-----BEGIN PUBLIC`` or
            ``-----BEGIN PRIVATE`` PEM header, or an ``ssh-`` prefix.

            :raises InvalidKeyError: if the text matches none of these forms,
                or the resulting key is not an OKP key
            """
            if not isinstance(key, (str, bytes)):
                self.check_crypto_key_type(key)
                return key

            if isinstance(key, bytes):
                key_bytes, key_str = key, key.decode("utf-8")
            else:
                key_bytes, key_str = key.encode("utf-8"), key

            loaded_key: PublicKeyTypes | PrivateKeyTypes
            if "-----BEGIN PUBLIC" in key_str:
                loaded_key = load_pem_public_key(key_bytes)
            elif "-----BEGIN PRIVATE" in key_str:
                loaded_key = load_pem_private_key(key_bytes, password=None)
            elif key_str.startswith("ssh-"):
                loaded_key = load_ssh_public_key(key_bytes)
            else:
                raise InvalidKeyError("Not a public or private key")

            # Explicit check the key to prevent confusing errors from cryptography
            self.check_crypto_key_type(loaded_key)
            return cast("AllowedOKPKeys", loaded_key)

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``.

            :param str|bytes msg: Message to sign; a ``str`` is UTF-8 encoded first
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            payload = msg.encode("utf-8") if isinstance(msg, str) else msg
            signature: bytes = key.sign(payload)
            return signature

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``.

            :param str|bytes sig: EdDSA signature to check ``msg`` against
            :param str|bytes msg: Message that was signed
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance; a private key is
                reduced to its public half
            :return bool verified: True if signature is valid, False if not.
            """
            data = msg.encode("utf-8") if isinstance(msg, str) else msg
            signature = sig.encode("utf-8") if isinstance(sig, str) else sig
            verifier = (
                key.public_key()
                if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                else key
            )
            try:
                verifier.verify(signature, data)
            except InvalidSignature:
                return False
            # No exception raised: the signature is valid.
            return True

        @overload
        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``OKP`` JWK.

            Public keys yield ``x``, ``kty`` and ``crv``; private keys also
            carry ``d`` (raw, unencrypted private bytes).

            :raises InvalidKeyError: for any other object
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                public_key = key
                private_raw: bytes | None = None
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
            elif isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                private_raw = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )
                public_key = key.public_key()
                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
            else:
                raise InvalidKeyError("Not a public or private key")

            public_raw = public_key.public_bytes(
                encoding=Encoding.Raw,
                format=PublicFormat.Raw,
            )

            # Member order in the output: x, [d,] kty, crv.
            obj: dict[str, Any] = {
                "x": base64url_encode(force_bytes(public_raw)).decode()
            }
            if private_raw is not None:
                obj["d"] = base64url_encode(force_bytes(private_raw)).decode()
            obj["kty"] = "OKP"
            obj["crv"] = crv

            return obj if as_dict else json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Rebuild an Ed25519/Ed448 key object from an ``OKP`` JWK.

            ``d`` present yields a private key; otherwise a public key built
            from ``x``.

            :raises InvalidKeyError: on invalid JSON, a non-OKP ``kty``, an
                unsupported curve, a missing ``x``, or key bytes rejected by
                ``cryptography``
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve not in ("Ed25519", "Ed448"):
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            use_25519 = curve == "Ed25519"
            try:
                if "d" in obj:
                    d = base64url_decode(obj.get("d"))
                    private_cls = Ed25519PrivateKey if use_25519 else Ed448PrivateKey
                    return private_cls.from_private_bytes(d)
                public_cls = Ed25519PublicKey if use_25519 else Ed448PublicKey
                return public_cls.from_public_bytes(x)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err