Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/algorithms.py: 32%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import hashlib
4import hmac
5import json
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
9from .exceptions import InvalidKeyError
10from .types import HashlibHash, JWKDict
11from .utils import (
12 base64url_decode,
13 base64url_encode,
14 der_to_raw_signature,
15 force_bytes,
16 from_base64url_uint,
17 is_pem_format,
18 is_ssh_key,
19 raw_to_der_signature,
20 to_base64url_uint,
21)
23try:
24 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
25 from cryptography.hazmat.backends import default_backend
26 from cryptography.hazmat.primitives import hashes
27 from cryptography.hazmat.primitives.asymmetric import padding
28 from cryptography.hazmat.primitives.asymmetric.ec import (
29 ECDSA,
30 SECP256K1,
31 SECP256R1,
32 SECP384R1,
33 SECP521R1,
34 EllipticCurve,
35 EllipticCurvePrivateKey,
36 EllipticCurvePrivateNumbers,
37 EllipticCurvePublicKey,
38 EllipticCurvePublicNumbers,
39 )
40 from cryptography.hazmat.primitives.asymmetric.ed448 import (
41 Ed448PrivateKey,
42 Ed448PublicKey,
43 )
44 from cryptography.hazmat.primitives.asymmetric.ed25519 import (
45 Ed25519PrivateKey,
46 Ed25519PublicKey,
47 )
48 from cryptography.hazmat.primitives.asymmetric.rsa import (
49 RSAPrivateKey,
50 RSAPrivateNumbers,
51 RSAPublicKey,
52 RSAPublicNumbers,
53 rsa_crt_dmp1,
54 rsa_crt_dmq1,
55 rsa_crt_iqmp,
56 rsa_recover_prime_factors,
57 )
58 from cryptography.hazmat.primitives.serialization import (
59 Encoding,
60 NoEncryption,
61 PrivateFormat,
62 PublicFormat,
63 load_pem_private_key,
64 load_pem_public_key,
65 load_ssh_public_key,
66 )
68 has_crypto = True
69except ModuleNotFoundError:
70 has_crypto = False
if TYPE_CHECKING:
    # Annotation-only aliases. TYPE_CHECKING is False at runtime, so none of
    # these names exist when the module is imported; they only serve static
    # type checkers reading the (lazily evaluated) signatures below.

    # Keys grouped by role, across every supported algorithm family.
    AllowedPrivateKeys = (
        RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
    )

    # Keys grouped by algorithm family (private or public half of each).
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
    )

    # Any key accepted by one of the families above.
    AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
# Algorithm names whose implementations come from the optional
# ``cryptography`` package -- the same names registered only under the
# ``has_crypto`` branch of get_default_algorithms().
requires_cryptography = set(
    "RS256 RS384 RS512 "
    "ES256 ES256K ES384 ES521 ES512 "
    "PS256 PS384 PS512 "
    "EdDSA".split()
)
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build the mapping from JWA algorithm name to a fresh algorithm instance.

    ``none`` and the HMAC family (HS256/384/512) are always present. The RSA,
    ECDSA, RSA-PSS and EdDSA entries are added only when the optional
    ``cryptography`` import succeeded (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}

    hmac_variants = (
        ("HS256", HMACAlgorithm.SHA256),
        ("HS384", HMACAlgorithm.SHA384),
        ("HS512", HMACAlgorithm.SHA512),
    )
    for alg_name, digest in hmac_variants:
        algorithms[alg_name] = HMACAlgorithm(digest)

    if not has_crypto:
        return algorithms

    for alg_name, digest in (
        ("RS256", RSAAlgorithm.SHA256),
        ("RS384", RSAAlgorithm.SHA384),
        ("RS512", RSAAlgorithm.SHA512),
    ):
        algorithms[alg_name] = RSAAlgorithm(digest)

    # "ES512" duplicates "ES521" on purpose: backward compat for the #219 fix.
    for alg_name, digest in (
        ("ES256", ECAlgorithm.SHA256),
        ("ES256K", ECAlgorithm.SHA256),
        ("ES384", ECAlgorithm.SHA384),
        ("ES521", ECAlgorithm.SHA512),
        ("ES512", ECAlgorithm.SHA512),
    ):
        algorithms[alg_name] = ECAlgorithm(digest)

    for alg_name, digest in (
        ("PS256", RSAPSSAlgorithm.SHA256),
        ("PS384", RSAPSSAlgorithm.SHA384),
        ("PS512", RSAPSSAlgorithm.SHA512),
    ):
        algorithms[alg_name] = RSAPSSAlgorithm(digest)

    algorithms["EdDSA"] = OKPAlgorithm()
    return algorithms
class Algorithm(ABC):
    """
    Abstract interface shared by every token signing/verification algorithm.

    Concrete subclasses normalize key material (``prepare_key``), produce and
    check signatures (``sign`` / ``verify``) and convert keys to and from the
    JWK representation (``to_jwk`` / ``from_jwk``).
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Hash ``bytestr`` with this algorithm's ``hash_alg``.

        ``hash_alg`` is either a ``cryptography`` hash class (used through
        ``hashes.Hash``) or a hashlib-style constructor whose result exposes
        ``.digest()``.

        :raises NotImplementedError: if the algorithm defines no ``hash_alg``.
        """
        # getattr() rather than self.hash_alg: not every subclass defines it,
        # and this spelling is one mypy understands.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_cryptography_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_cryptography_hash:
            # hashlib-style constructor, e.g. hashlib.sha256(data).digest().
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate ``key`` and convert it into the form that ``sign()`` and
        ``verify()`` expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Return the signature of ``msg`` computed with ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Return whether ``sig`` is a valid signature of ``msg`` under ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Rebuild a key object from a JWK given as a JSON string or a dict.
        """
class NoneAlgorithm(Algorithm):
    """
    Implementation of ``alg: "none"``, i.e. unsigned tokens.

    The only acceptable key is ``None`` (an empty string is treated the same).
    Signing yields an empty signature, and verification never succeeds.
    """

    def prepare_key(self, key: str | None) -> None:
        # "" is normalized to None; anything else is an error.
        normalized = None if key == "" else key
        if normalized is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsigned tokens carry an empty signature segment.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Deliberately never reports an unsigned token as verified.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialize.
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # hashlib constructor (e.g. hashlib.sha256), passed to hmac.new as digestmod.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return the HMAC secret as bytes.

        :raises InvalidKeyError: if the secret looks like a PEM-encoded key or
            certificate, or an SSH key. Such (often public) key material must
            not double as an HMAC secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an HMAC secret as an ``oct`` JWK (dict or JSON string).
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode an ``oct`` JWK (JSON string or dict) into the raw secret bytes.

        :raises InvalidKeyError: if the input is not a JSON object, is not an
            ``oct`` key, or has no ``k`` member.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
            # JSON text such as "[1]" or "null" parses but is not a JWK object.
            if not isinstance(obj, dict):
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        # "k" is required for oct keys; report its absence as InvalidKeyError
        # rather than letting a bare KeyError escape to callers.
        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison of the expected and supplied signatures.
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            # A cryptography hash class; instantiated for every sign/verify call.
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Normalize ``key`` into an RSA key object.

            RSA key objects are returned unchanged. ``str``/``bytes`` input is
            loaded as an OpenSSH public key when it starts with ``ssh-rsa``.
            Otherwise it is loaded as an unencrypted PEM private key, falling
            back to a PEM public key.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the data cannot be parsed, or parses to
                a key that is not an RSA key.
            """
            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            crypto_key: object
            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    crypto_key = load_ssh_public_key(key_bytes)
                else:
                    crypto_key = load_pem_private_key(key_bytes, password=None)
            except ValueError:
                try:
                    crypto_key = load_pem_public_key(key_bytes)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError("Could not parse the provided public key.")

            # The PEM loaders return whatever key type the data contains (EC,
            # Ed25519, ...). Reject non-RSA keys here, as ECAlgorithm and
            # OKPAlgorithm do, instead of failing later inside sign()/verify().
            if not isinstance(crypto_key, (RSAPrivateKey, RSAPublicKey)):
                raise InvalidKeyError(
                    "Expecting a RSAPrivateKey/RSAPublicKey. Wrong key provided for RSA algorithms"
                )

            return crypto_key

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK (dict or JSON string).

            Private keys include all CRT parameters (``d``, ``p``, ``q``,
            ``dp``, ``dq``, ``qi``). Public keys carry only ``n`` and ``e``.
            """
            obj: dict[str, Any] | None = None

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Rebuild an RSA key from a JWK given as a JSON string or dict.

            A JWK with ``d``, ``e`` and ``n`` becomes a private key. When the
            CRT parameters are absent, the primes are recovered from ``d``.
            A JWK with only ``n`` and ``e`` becomes a public key.

            :raises InvalidKeyError: on malformed input, a non-RSA ``kty``,
                multi-prime keys (``oth``), or partial CRT parameters.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # JSON text such as "[1]" parses but is not a JWK object.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON")

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key")

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    )

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d is given: recover p and q, then derive CRT values.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False
484 class ECAlgorithm(Algorithm):
485 """
486 Performs signing and verification operations using
487 ECDSA and the specified hash function
488 """
490 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
491 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
492 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
494 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
495 self.hash_alg = hash_alg
497 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
498 if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
499 return key
501 if not isinstance(key, (bytes, str)):
502 raise TypeError("Expecting a PEM-formatted key.")
504 key_bytes = force_bytes(key)
506 # Attempt to load key. We don't know if it's
507 # a Signing Key or a Verifying Key, so we try
508 # the Verifying Key first.
509 try:
510 if key_bytes.startswith(b"ecdsa-sha2-"):
511 crypto_key = load_ssh_public_key(key_bytes)
512 else:
513 crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
514 except ValueError:
515 crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
517 # Explicit check the key to prevent confusing errors from cryptography
518 if not isinstance(
519 crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
520 ):
521 raise InvalidKeyError(
522 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
523 )
525 return crypto_key
527 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
528 der_sig = key.sign(msg, ECDSA(self.hash_alg()))
530 return der_to_raw_signature(der_sig, key.curve)
532 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
533 try:
534 der_sig = raw_to_der_signature(sig, key.curve)
535 except ValueError:
536 return False
538 try:
539 public_key = (
540 key.public_key()
541 if isinstance(key, EllipticCurvePrivateKey)
542 else key
543 )
544 public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
545 return True
546 except InvalidSignature:
547 return False
549 @overload
550 @staticmethod
551 def to_jwk(
552 key_obj: AllowedECKeys, as_dict: Literal[True]
553 ) -> JWKDict: ... # pragma: no cover
555 @overload
556 @staticmethod
557 def to_jwk(
558 key_obj: AllowedECKeys, as_dict: Literal[False] = False
559 ) -> str: ... # pragma: no cover
561 @staticmethod
562 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
563 if isinstance(key_obj, EllipticCurvePrivateKey):
564 public_numbers = key_obj.public_key().public_numbers()
565 elif isinstance(key_obj, EllipticCurvePublicKey):
566 public_numbers = key_obj.public_numbers()
567 else:
568 raise InvalidKeyError("Not a public or private key")
570 if isinstance(key_obj.curve, SECP256R1):
571 crv = "P-256"
572 elif isinstance(key_obj.curve, SECP384R1):
573 crv = "P-384"
574 elif isinstance(key_obj.curve, SECP521R1):
575 crv = "P-521"
576 elif isinstance(key_obj.curve, SECP256K1):
577 crv = "secp256k1"
578 else:
579 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
581 obj: dict[str, Any] = {
582 "kty": "EC",
583 "crv": crv,
584 "x": to_base64url_uint(public_numbers.x).decode(),
585 "y": to_base64url_uint(public_numbers.y).decode(),
586 }
588 if isinstance(key_obj, EllipticCurvePrivateKey):
589 obj["d"] = to_base64url_uint(
590 key_obj.private_numbers().private_value
591 ).decode()
593 if as_dict:
594 return obj
595 else:
596 return json.dumps(obj)
598 @staticmethod
599 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
600 try:
601 if isinstance(jwk, str):
602 obj = json.loads(jwk)
603 elif isinstance(jwk, dict):
604 obj = jwk
605 else:
606 raise ValueError
607 except ValueError:
608 raise InvalidKeyError("Key is not valid JSON")
610 if obj.get("kty") != "EC":
611 raise InvalidKeyError("Not an Elliptic curve key")
613 if "x" not in obj or "y" not in obj:
614 raise InvalidKeyError("Not an Elliptic curve key")
616 x = base64url_decode(obj.get("x"))
617 y = base64url_decode(obj.get("y"))
619 curve = obj.get("crv")
620 curve_obj: EllipticCurve
622 if curve == "P-256":
623 if len(x) == len(y) == 32:
624 curve_obj = SECP256R1()
625 else:
626 raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
627 elif curve == "P-384":
628 if len(x) == len(y) == 48:
629 curve_obj = SECP384R1()
630 else:
631 raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
632 elif curve == "P-521":
633 if len(x) == len(y) == 66:
634 curve_obj = SECP521R1()
635 else:
636 raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
637 elif curve == "secp256k1":
638 if len(x) == len(y) == 32:
639 curve_obj = SECP256K1()
640 else:
641 raise InvalidKeyError(
642 "Coords should be 32 bytes for curve secp256k1"
643 )
644 else:
645 raise InvalidKeyError(f"Invalid curve: {curve}")
647 public_numbers = EllipticCurvePublicNumbers(
648 x=int.from_bytes(x, byteorder="big"),
649 y=int.from_bytes(y, byteorder="big"),
650 curve=curve_obj,
651 )
653 if "d" not in obj:
654 return public_numbers.public_key()
656 d = base64url_decode(obj.get("d"))
657 if len(d) != len(x):
658 raise InvalidKeyError(
659 "D should be {} bytes for curve {}", len(x), curve
660 )
662 return EllipticCurvePrivateNumbers(
663 int.from_bytes(d, byteorder="big"), public_numbers
664 ).private_key()
666 class RSAPSSAlgorithm(RSAAlgorithm):
667 """
668 Performs a signature using RSASSA-PSS with MGF1
669 """
671 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
672 return key.sign(
673 msg,
674 padding.PSS(
675 mgf=padding.MGF1(self.hash_alg()),
676 salt_length=self.hash_alg().digest_size,
677 ),
678 self.hash_alg(),
679 )
681 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
682 try:
683 key.verify(
684 sig,
685 msg,
686 padding.PSS(
687 mgf=padding.MGF1(self.hash_alg()),
688 salt_length=self.hash_alg().digest_size,
689 ),
690 self.hash_alg(),
691 )
692 return True
693 except InvalidSignature:
694 return False
696 class OKPAlgorithm(Algorithm):
697 """
698 Performs signing and verification operations using EdDSA
700 This class requires ``cryptography>=2.6`` to be installed.
701 """
703 def __init__(self, **kwargs: Any) -> None:
704 pass
706 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
707 if isinstance(key, (bytes, str)):
708 key_str = key.decode("utf-8") if isinstance(key, bytes) else key
709 key_bytes = key.encode("utf-8") if isinstance(key, str) else key
711 if "-----BEGIN PUBLIC" in key_str:
712 key = load_pem_public_key(key_bytes) # type: ignore[assignment]
713 elif "-----BEGIN PRIVATE" in key_str:
714 key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
715 elif key_str[0:4] == "ssh-":
716 key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
718 # Explicit check the key to prevent confusing errors from cryptography
719 if not isinstance(
720 key,
721 (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
722 ):
723 raise InvalidKeyError(
724 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
725 )
727 return key
729 def sign(
730 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
731 ) -> bytes:
732 """
733 Sign a message ``msg`` using the EdDSA private key ``key``
734 :param str|bytes msg: Message to sign
735 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
736 or :class:`.Ed448PrivateKey` isinstance
737 :return bytes signature: The signature, as bytes
738 """
739 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
740 return key.sign(msg_bytes)
742 def verify(
743 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
744 ) -> bool:
745 """
746 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
748 :param str|bytes sig: EdDSA signature to check ``msg`` against
749 :param str|bytes msg: Message to sign
750 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
751 A private or public EdDSA key instance
752 :return bool verified: True if signature is valid, False if not.
753 """
754 try:
755 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
756 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
758 public_key = (
759 key.public_key()
760 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
761 else key
762 )
763 public_key.verify(sig_bytes, msg_bytes)
764 return True # If no exception was raised, the signature is valid.
765 except InvalidSignature:
766 return False
768 @overload
769 @staticmethod
770 def to_jwk(
771 key: AllowedOKPKeys, as_dict: Literal[True]
772 ) -> JWKDict: ... # pragma: no cover
774 @overload
775 @staticmethod
776 def to_jwk(
777 key: AllowedOKPKeys, as_dict: Literal[False] = False
778 ) -> str: ... # pragma: no cover
780 @staticmethod
781 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
782 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
783 x = key.public_bytes(
784 encoding=Encoding.Raw,
785 format=PublicFormat.Raw,
786 )
787 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
789 obj = {
790 "x": base64url_encode(force_bytes(x)).decode(),
791 "kty": "OKP",
792 "crv": crv,
793 }
795 if as_dict:
796 return obj
797 else:
798 return json.dumps(obj)
800 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
801 d = key.private_bytes(
802 encoding=Encoding.Raw,
803 format=PrivateFormat.Raw,
804 encryption_algorithm=NoEncryption(),
805 )
807 x = key.public_key().public_bytes(
808 encoding=Encoding.Raw,
809 format=PublicFormat.Raw,
810 )
812 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
813 obj = {
814 "x": base64url_encode(force_bytes(x)).decode(),
815 "d": base64url_encode(force_bytes(d)).decode(),
816 "kty": "OKP",
817 "crv": crv,
818 }
820 if as_dict:
821 return obj
822 else:
823 return json.dumps(obj)
825 raise InvalidKeyError("Not a public or private key")
827 @staticmethod
828 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
829 try:
830 if isinstance(jwk, str):
831 obj = json.loads(jwk)
832 elif isinstance(jwk, dict):
833 obj = jwk
834 else:
835 raise ValueError
836 except ValueError:
837 raise InvalidKeyError("Key is not valid JSON")
839 if obj.get("kty") != "OKP":
840 raise InvalidKeyError("Not an Octet Key Pair")
842 curve = obj.get("crv")
843 if curve != "Ed25519" and curve != "Ed448":
844 raise InvalidKeyError(f"Invalid curve: {curve}")
846 if "x" not in obj:
847 raise InvalidKeyError('OKP should have "x" parameter')
848 x = base64url_decode(obj.get("x"))
850 try:
851 if "d" not in obj:
852 if curve == "Ed25519":
853 return Ed25519PublicKey.from_public_bytes(x)
854 return Ed448PublicKey.from_public_bytes(x)
855 d = base64url_decode(obj.get("d"))
856 if curve == "Ed25519":
857 return Ed25519PrivateKey.from_private_bytes(d)
858 return Ed448PrivateKey.from_private_bytes(d)
859 except ValueError as err:
860 raise InvalidKeyError("Invalid key parameter") from err