Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 18%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
9from .exceptions import InvalidKeyError
10from .types import HashlibHash, JWKDict
11from .utils import (
12 base64url_decode,
13 base64url_encode,
14 der_to_raw_signature,
15 force_bytes,
16 from_base64url_uint,
17 is_pem_format,
18 is_ssh_key,
19 raw_to_der_signature,
20 to_base64url_uint,
21)
23try:
24 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
25 from cryptography.hazmat.backends import default_backend
26 from cryptography.hazmat.primitives import hashes
27 from cryptography.hazmat.primitives.asymmetric import padding
28 from cryptography.hazmat.primitives.asymmetric.ec import (
29 ECDSA,
30 SECP256K1,
31 SECP256R1,
32 SECP384R1,
33 SECP521R1,
34 EllipticCurve,
35 EllipticCurvePrivateKey,
36 EllipticCurvePrivateNumbers,
37 EllipticCurvePublicKey,
38 EllipticCurvePublicNumbers,
39 )
40 from cryptography.hazmat.primitives.asymmetric.ed448 import (
41 Ed448PrivateKey,
42 Ed448PublicKey,
43 )
44 from cryptography.hazmat.primitives.asymmetric.ed25519 import (
45 Ed25519PrivateKey,
46 Ed25519PublicKey,
47 )
48 from cryptography.hazmat.primitives.asymmetric.rsa import (
49 RSAPrivateKey,
50 RSAPrivateNumbers,
51 RSAPublicKey,
52 RSAPublicNumbers,
53 rsa_crt_dmp1,
54 rsa_crt_dmq1,
55 rsa_crt_iqmp,
56 rsa_recover_prime_factors,
57 )
58 from cryptography.hazmat.primitives.serialization import (
59 Encoding,
60 NoEncryption,
61 PrivateFormat,
62 PublicFormat,
63 load_pem_private_key,
64 load_pem_public_key,
65 load_ssh_public_key,
66 )
68 has_crypto = True
69except ModuleNotFoundError:
70 has_crypto = False
if TYPE_CHECKING:
    # Static-typing-only aliases describing which ``cryptography`` key objects
    # the algorithm classes accept. They are never created at runtime.

    # One alias per algorithm family.
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
    )

    # Any supported key, whatever the family.
    AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys

    # The same keys grouped by role instead of by family.
    AllowedPrivateKeys = (
        RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
    )
# JWS ``alg`` names whose implementations live in classes that are only
# defined when the optional ``cryptography`` package imported successfully.
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256",
    "RS384",
    "RS512",
    # RSASSA-PSS
    "PS256",
    "PS384",
    "PS512",
    # ECDSA ("ES512" is kept alongside "ES521" for backward compatibility)
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    # EdDSA (Ed25519 / Ed448)
    "EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build a fresh mapping from JWS ``alg`` name to algorithm instance.

    ``"none"`` and the HMAC family (``HS256``/``HS384``/``HS512``) are always
    present. The RSA, ECDSA, RSA-PSS and EdDSA entries are appended only when
    the optional ``cryptography`` import succeeded (``has_crypto``).
    """
    algorithms = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    # Without cryptography only the hashlib-based algorithms are available.
    if not has_crypto:
        return algorithms

    algorithms.update(
        {
            "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
            "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
            "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
            "ES256": ECAlgorithm(ECAlgorithm.SHA256),
            "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
            "ES384": ECAlgorithm(ECAlgorithm.SHA384),
            "ES521": ECAlgorithm(ECAlgorithm.SHA512),
            # Backward compat for #219 fix
            "ES512": ECAlgorithm(ECAlgorithm.SHA512),
            "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
            "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
            "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
            "EdDSA": OKPAlgorithm(),
        }
    )
    return algorithms
class Algorithm(ABC):
    """
    Abstract interface every JWS algorithm implements: key preparation,
    signing, verification and JWK (de)serialization.
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Hash ``bytestr`` with this algorithm's ``hash_alg`` and return the digest.

        ``hash_alg`` is either a ``cryptography`` ``HashAlgorithm`` subclass
        (used only when ``cryptography`` is available) or a hashlib-style
        constructor such as ``hashlib.sha256``.

        :raises NotImplementedError: if the instance has no ``hash_alg``
            attribute, or it is ``None``.
        """
        # getattr with a default keeps mypy happy: not every subclass defines hash_alg.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_cryptography_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_cryptography_hash:
            # hashlib-style constructor: call it with the data, then take the digest.
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate ``key`` and convert it into the form that :meth:`sign` and
        :meth:`verify` expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Produce the signature of ``msg`` using ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Report whether ``sig`` is a valid signature of ``msg`` under ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Rebuild a key object from a JWK given as a JSON string or dict.
        """
class NoneAlgorithm(Algorithm):
    """
    The ``alg = "none"`` algorithm. It accepts no key, produces an empty
    signature, and never reports a signature as valid.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is accepted as a synonym for "no key".
        no_key = key == "" or key is None
        if not no_key:
            raise InvalidKeyError('When alg = "none", key value must be None.')
        return None

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsecured JWS: the signature segment is always empty.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Nothing can be verified without a key, so always reject.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # hashlib constructor (e.g. hashlib.sha256) handed to hmac.new as digestmod.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes for use as the HMAC secret.

        :raises InvalidKeyError: if the key looks like a PEM-encoded or SSH
            asymmetric key. Such a value must not be used as a shared secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an HMAC secret as an ``"oct"`` JWK.

        The secret is base64url-encoded into ``"k"``. Returns a dict when
        ``as_dict`` is true, otherwise a JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode an ``"oct"`` JWK (JSON text or an already-parsed dict) into the
        raw HMAC secret.

        :raises InvalidKeyError: if ``jwk`` is not a JSON object, is not an
            ``"oct"`` key, or has no ``"k"`` member.
        """
        try:
            obj = json.loads(jwk) if isinstance(jwk, str) else jwk
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        # Reject non-objects here (wrong Python type, or JSON such as "[1]");
        # otherwise obj.get below would raise AttributeError.
        if not isinstance(obj, dict):
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        # Report a missing secret as InvalidKeyError instead of a bare KeyError.
        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` under ``key`` using ``hash_alg``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """Recompute the HMAC and compare it with ``sig`` in constant time."""
        return hmac.compare_digest(sig, self.sign(msg, key))
314if has_crypto:
316 class RSAAlgorithm(Algorithm):
317 """
318 Performs signing and verification operations using
319 RSASSA-PKCS-v1_5 and the specified hash function.
320 """
322 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
323 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
324 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
326 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
327 self.hash_alg = hash_alg
329 def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
330 if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
331 return key
333 if not isinstance(key, (bytes, str)):
334 raise TypeError("Expecting a PEM-formatted key.")
336 key_bytes = force_bytes(key)
338 try:
339 if key_bytes.startswith(b"ssh-rsa"):
340 return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
341 else:
342 return cast(
343 RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
344 )
345 except ValueError:
346 try:
347 return cast(RSAPublicKey, load_pem_public_key(key_bytes))
348 except (ValueError, UnsupportedAlgorithm):
349 raise InvalidKeyError(
350 "Could not parse the provided public key."
351 ) from None
353 @overload
354 @staticmethod
355 def to_jwk(
356 key_obj: AllowedRSAKeys, as_dict: Literal[True]
357 ) -> JWKDict: ... # pragma: no cover
359 @overload
360 @staticmethod
361 def to_jwk(
362 key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
363 ) -> str: ... # pragma: no cover
365 @staticmethod
366 def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
367 obj: dict[str, Any] | None = None
369 if hasattr(key_obj, "private_numbers"):
370 # Private key
371 numbers = key_obj.private_numbers()
373 obj = {
374 "kty": "RSA",
375 "key_ops": ["sign"],
376 "n": to_base64url_uint(numbers.public_numbers.n).decode(),
377 "e": to_base64url_uint(numbers.public_numbers.e).decode(),
378 "d": to_base64url_uint(numbers.d).decode(),
379 "p": to_base64url_uint(numbers.p).decode(),
380 "q": to_base64url_uint(numbers.q).decode(),
381 "dp": to_base64url_uint(numbers.dmp1).decode(),
382 "dq": to_base64url_uint(numbers.dmq1).decode(),
383 "qi": to_base64url_uint(numbers.iqmp).decode(),
384 }
386 elif hasattr(key_obj, "verify"):
387 # Public key
388 numbers = key_obj.public_numbers()
390 obj = {
391 "kty": "RSA",
392 "key_ops": ["verify"],
393 "n": to_base64url_uint(numbers.n).decode(),
394 "e": to_base64url_uint(numbers.e).decode(),
395 }
396 else:
397 raise InvalidKeyError("Not a public or private key")
399 if as_dict:
400 return obj
401 else:
402 return json.dumps(obj)
404 @staticmethod
405 def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
406 try:
407 if isinstance(jwk, str):
408 obj = json.loads(jwk)
409 elif isinstance(jwk, dict):
410 obj = jwk
411 else:
412 raise ValueError
413 except ValueError:
414 raise InvalidKeyError("Key is not valid JSON") from None
416 if obj.get("kty") != "RSA":
417 raise InvalidKeyError("Not an RSA key") from None
419 if "d" in obj and "e" in obj and "n" in obj:
420 # Private key
421 if "oth" in obj:
422 raise InvalidKeyError(
423 "Unsupported RSA private key: > 2 primes not supported"
424 )
426 other_props = ["p", "q", "dp", "dq", "qi"]
427 props_found = [prop in obj for prop in other_props]
428 any_props_found = any(props_found)
430 if any_props_found and not all(props_found):
431 raise InvalidKeyError(
432 "RSA key must include all parameters if any are present besides d"
433 ) from None
435 public_numbers = RSAPublicNumbers(
436 from_base64url_uint(obj["e"]),
437 from_base64url_uint(obj["n"]),
438 )
440 if any_props_found:
441 numbers = RSAPrivateNumbers(
442 d=from_base64url_uint(obj["d"]),
443 p=from_base64url_uint(obj["p"]),
444 q=from_base64url_uint(obj["q"]),
445 dmp1=from_base64url_uint(obj["dp"]),
446 dmq1=from_base64url_uint(obj["dq"]),
447 iqmp=from_base64url_uint(obj["qi"]),
448 public_numbers=public_numbers,
449 )
450 else:
451 d = from_base64url_uint(obj["d"])
452 p, q = rsa_recover_prime_factors(
453 public_numbers.n, d, public_numbers.e
454 )
456 numbers = RSAPrivateNumbers(
457 d=d,
458 p=p,
459 q=q,
460 dmp1=rsa_crt_dmp1(d, p),
461 dmq1=rsa_crt_dmq1(d, q),
462 iqmp=rsa_crt_iqmp(p, q),
463 public_numbers=public_numbers,
464 )
466 return numbers.private_key()
467 elif "n" in obj and "e" in obj:
468 # Public key
469 return RSAPublicNumbers(
470 from_base64url_uint(obj["e"]),
471 from_base64url_uint(obj["n"]),
472 ).public_key()
473 else:
474 raise InvalidKeyError("Not a public or private key")
476 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
477 return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
479 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
480 try:
481 key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
482 return True
483 except InvalidSignature:
484 return False
486 class ECAlgorithm(Algorithm):
487 """
488 Performs signing and verification operations using
489 ECDSA and the specified hash function
490 """
492 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
493 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
494 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
496 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
497 self.hash_alg = hash_alg
499 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
500 if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
501 return key
503 if not isinstance(key, (bytes, str)):
504 raise TypeError("Expecting a PEM-formatted key.")
506 key_bytes = force_bytes(key)
508 # Attempt to load key. We don't know if it's
509 # a Signing Key or a Verifying Key, so we try
510 # the Verifying Key first.
511 try:
512 if key_bytes.startswith(b"ecdsa-sha2-"):
513 crypto_key = load_ssh_public_key(key_bytes)
514 else:
515 crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
516 except ValueError:
517 crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
519 # Explicit check the key to prevent confusing errors from cryptography
520 if not isinstance(
521 crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
522 ):
523 raise InvalidKeyError(
524 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
525 ) from None
527 return crypto_key
529 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
530 der_sig = key.sign(msg, ECDSA(self.hash_alg()))
532 return der_to_raw_signature(der_sig, key.curve)
534 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
535 try:
536 der_sig = raw_to_der_signature(sig, key.curve)
537 except ValueError:
538 return False
540 try:
541 public_key = (
542 key.public_key()
543 if isinstance(key, EllipticCurvePrivateKey)
544 else key
545 )
546 public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
547 return True
548 except InvalidSignature:
549 return False
551 @overload
552 @staticmethod
553 def to_jwk(
554 key_obj: AllowedECKeys, as_dict: Literal[True]
555 ) -> JWKDict: ... # pragma: no cover
557 @overload
558 @staticmethod
559 def to_jwk(
560 key_obj: AllowedECKeys, as_dict: Literal[False] = False
561 ) -> str: ... # pragma: no cover
563 @staticmethod
564 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
565 if isinstance(key_obj, EllipticCurvePrivateKey):
566 public_numbers = key_obj.public_key().public_numbers()
567 elif isinstance(key_obj, EllipticCurvePublicKey):
568 public_numbers = key_obj.public_numbers()
569 else:
570 raise InvalidKeyError("Not a public or private key")
572 if isinstance(key_obj.curve, SECP256R1):
573 crv = "P-256"
574 elif isinstance(key_obj.curve, SECP384R1):
575 crv = "P-384"
576 elif isinstance(key_obj.curve, SECP521R1):
577 crv = "P-521"
578 elif isinstance(key_obj.curve, SECP256K1):
579 crv = "secp256k1"
580 else:
581 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
583 obj: dict[str, Any] = {
584 "kty": "EC",
585 "crv": crv,
586 "x": to_base64url_uint(
587 public_numbers.x,
588 bit_length=key_obj.curve.key_size,
589 ).decode(),
590 "y": to_base64url_uint(
591 public_numbers.y,
592 bit_length=key_obj.curve.key_size,
593 ).decode(),
594 }
596 if isinstance(key_obj, EllipticCurvePrivateKey):
597 obj["d"] = to_base64url_uint(
598 key_obj.private_numbers().private_value,
599 bit_length=key_obj.curve.key_size,
600 ).decode()
602 if as_dict:
603 return obj
604 else:
605 return json.dumps(obj)
607 @staticmethod
608 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
609 try:
610 if isinstance(jwk, str):
611 obj = json.loads(jwk)
612 elif isinstance(jwk, dict):
613 obj = jwk
614 else:
615 raise ValueError
616 except ValueError:
617 raise InvalidKeyError("Key is not valid JSON") from None
619 if obj.get("kty") != "EC":
620 raise InvalidKeyError("Not an Elliptic curve key") from None
622 if "x" not in obj or "y" not in obj:
623 raise InvalidKeyError("Not an Elliptic curve key") from None
625 x = base64url_decode(obj.get("x"))
626 y = base64url_decode(obj.get("y"))
628 curve = obj.get("crv")
629 curve_obj: EllipticCurve
631 if curve == "P-256":
632 if len(x) == len(y) == 32:
633 curve_obj = SECP256R1()
634 else:
635 raise InvalidKeyError(
636 "Coords should be 32 bytes for curve P-256"
637 ) from None
638 elif curve == "P-384":
639 if len(x) == len(y) == 48:
640 curve_obj = SECP384R1()
641 else:
642 raise InvalidKeyError(
643 "Coords should be 48 bytes for curve P-384"
644 ) from None
645 elif curve == "P-521":
646 if len(x) == len(y) == 66:
647 curve_obj = SECP521R1()
648 else:
649 raise InvalidKeyError(
650 "Coords should be 66 bytes for curve P-521"
651 ) from None
652 elif curve == "secp256k1":
653 if len(x) == len(y) == 32:
654 curve_obj = SECP256K1()
655 else:
656 raise InvalidKeyError(
657 "Coords should be 32 bytes for curve secp256k1"
658 )
659 else:
660 raise InvalidKeyError(f"Invalid curve: {curve}")
662 public_numbers = EllipticCurvePublicNumbers(
663 x=int.from_bytes(x, byteorder="big"),
664 y=int.from_bytes(y, byteorder="big"),
665 curve=curve_obj,
666 )
668 if "d" not in obj:
669 return public_numbers.public_key()
671 d = base64url_decode(obj.get("d"))
672 if len(d) != len(x):
673 raise InvalidKeyError(
674 "D should be {} bytes for curve {}", len(x), curve
675 )
677 return EllipticCurvePrivateNumbers(
678 int.from_bytes(d, byteorder="big"), public_numbers
679 ).private_key()
681 class RSAPSSAlgorithm(RSAAlgorithm):
682 """
683 Performs a signature using RSASSA-PSS with MGF1
684 """
686 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
687 return key.sign(
688 msg,
689 padding.PSS(
690 mgf=padding.MGF1(self.hash_alg()),
691 salt_length=self.hash_alg().digest_size,
692 ),
693 self.hash_alg(),
694 )
696 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
697 try:
698 key.verify(
699 sig,
700 msg,
701 padding.PSS(
702 mgf=padding.MGF1(self.hash_alg()),
703 salt_length=self.hash_alg().digest_size,
704 ),
705 self.hash_alg(),
706 )
707 return True
708 except InvalidSignature:
709 return False
711 class OKPAlgorithm(Algorithm):
712 """
713 Performs signing and verification operations using EdDSA
715 This class requires ``cryptography>=2.6`` to be installed.
716 """
718 def __init__(self, **kwargs: Any) -> None:
719 pass
721 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
722 if isinstance(key, (bytes, str)):
723 key_str = key.decode("utf-8") if isinstance(key, bytes) else key
724 key_bytes = key.encode("utf-8") if isinstance(key, str) else key
726 if "-----BEGIN PUBLIC" in key_str:
727 key = load_pem_public_key(key_bytes) # type: ignore[assignment]
728 elif "-----BEGIN PRIVATE" in key_str:
729 key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
730 elif key_str[0:4] == "ssh-":
731 key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
733 # Explicit check the key to prevent confusing errors from cryptography
734 if not isinstance(
735 key,
736 (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
737 ):
738 raise InvalidKeyError(
739 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
740 )
742 return key
744 def sign(
745 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
746 ) -> bytes:
747 """
748 Sign a message ``msg`` using the EdDSA private key ``key``
749 :param str|bytes msg: Message to sign
750 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
751 or :class:`.Ed448PrivateKey` isinstance
752 :return bytes signature: The signature, as bytes
753 """
754 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
755 return key.sign(msg_bytes)
757 def verify(
758 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
759 ) -> bool:
760 """
761 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
763 :param str|bytes sig: EdDSA signature to check ``msg`` against
764 :param str|bytes msg: Message to sign
765 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
766 A private or public EdDSA key instance
767 :return bool verified: True if signature is valid, False if not.
768 """
769 try:
770 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
771 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
773 public_key = (
774 key.public_key()
775 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
776 else key
777 )
778 public_key.verify(sig_bytes, msg_bytes)
779 return True # If no exception was raised, the signature is valid.
780 except InvalidSignature:
781 return False
783 @overload
784 @staticmethod
785 def to_jwk(
786 key: AllowedOKPKeys, as_dict: Literal[True]
787 ) -> JWKDict: ... # pragma: no cover
789 @overload
790 @staticmethod
791 def to_jwk(
792 key: AllowedOKPKeys, as_dict: Literal[False] = False
793 ) -> str: ... # pragma: no cover
795 @staticmethod
796 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
797 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
798 x = key.public_bytes(
799 encoding=Encoding.Raw,
800 format=PublicFormat.Raw,
801 )
802 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
804 obj = {
805 "x": base64url_encode(force_bytes(x)).decode(),
806 "kty": "OKP",
807 "crv": crv,
808 }
810 if as_dict:
811 return obj
812 else:
813 return json.dumps(obj)
815 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
816 d = key.private_bytes(
817 encoding=Encoding.Raw,
818 format=PrivateFormat.Raw,
819 encryption_algorithm=NoEncryption(),
820 )
822 x = key.public_key().public_bytes(
823 encoding=Encoding.Raw,
824 format=PublicFormat.Raw,
825 )
827 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
828 obj = {
829 "x": base64url_encode(force_bytes(x)).decode(),
830 "d": base64url_encode(force_bytes(d)).decode(),
831 "kty": "OKP",
832 "crv": crv,
833 }
835 if as_dict:
836 return obj
837 else:
838 return json.dumps(obj)
840 raise InvalidKeyError("Not a public or private key")
842 @staticmethod
843 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
844 try:
845 if isinstance(jwk, str):
846 obj = json.loads(jwk)
847 elif isinstance(jwk, dict):
848 obj = jwk
849 else:
850 raise ValueError
851 except ValueError:
852 raise InvalidKeyError("Key is not valid JSON") from None
854 if obj.get("kty") != "OKP":
855 raise InvalidKeyError("Not an Octet Key Pair")
857 curve = obj.get("crv")
858 if curve != "Ed25519" and curve != "Ed448":
859 raise InvalidKeyError(f"Invalid curve: {curve}")
861 if "x" not in obj:
862 raise InvalidKeyError('OKP should have "x" parameter')
863 x = base64url_decode(obj.get("x"))
865 try:
866 if "d" not in obj:
867 if curve == "Ed25519":
868 return Ed25519PublicKey.from_public_bytes(x)
869 return Ed448PublicKey.from_public_bytes(x)
870 d = base64url_decode(obj.get("d"))
871 if curve == "Ed25519":
872 return Ed25519PrivateKey.from_private_bytes(d)
873 return Ed448PrivateKey.from_private_bytes(d)
874 except ValueError as err:
875 raise InvalidKeyError("Invalid key parameter") from err