Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

415 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6import os 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload 

9 

10from .exceptions import InvalidKeyError 

11from .types import HashlibHash, JWKDict 

12from .utils import ( 

13 base64url_decode, 

14 base64url_encode, 

15 der_to_raw_signature, 

16 force_bytes, 

17 from_base64url_uint, 

18 is_pem_format, 

19 is_ssh_key, 

20 raw_to_der_signature, 

21 to_base64url_uint, 

22) 

23 

24try: 

25 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

26 from cryptography.hazmat.backends import default_backend 

27 from cryptography.hazmat.primitives import hashes 

28 from cryptography.hazmat.primitives.asymmetric import padding 

29 from cryptography.hazmat.primitives.asymmetric.ec import ( 

30 ECDSA, 

31 SECP256K1, 

32 SECP256R1, 

33 SECP384R1, 

34 SECP521R1, 

35 EllipticCurve, 

36 EllipticCurvePrivateKey, 

37 EllipticCurvePrivateNumbers, 

38 EllipticCurvePublicKey, 

39 EllipticCurvePublicNumbers, 

40 ) 

41 from cryptography.hazmat.primitives.asymmetric.ed448 import ( 

42 Ed448PrivateKey, 

43 Ed448PublicKey, 

44 ) 

45 from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

46 Ed25519PrivateKey, 

47 Ed25519PublicKey, 

48 ) 

49 from cryptography.hazmat.primitives.asymmetric.rsa import ( 

50 RSAPrivateKey, 

51 RSAPrivateNumbers, 

52 RSAPublicKey, 

53 RSAPublicNumbers, 

54 rsa_crt_dmp1, 

55 rsa_crt_dmq1, 

56 rsa_crt_iqmp, 

57 rsa_recover_prime_factors, 

58 ) 

59 from cryptography.hazmat.primitives.serialization import ( 

60 Encoding, 

61 NoEncryption, 

62 PrivateFormat, 

63 PublicFormat, 

64 load_pem_private_key, 

65 load_pem_public_key, 

66 load_ssh_public_key, 

67 ) 

68 

69 # pyjwt-964: we use these both for type checking below, as well as for validating the key passed in. 

70 # in Py >= 3.10, we can replace this with the Union types below 

71 ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey) 

72 ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey) 

73 ALLOWED_OKP_KEY_TYPES = ( 

74 Ed25519PrivateKey, 

75 Ed25519PublicKey, 

76 Ed448PrivateKey, 

77 Ed448PublicKey, 

78 ) 

79 ALLOWED_KEY_TYPES = ( 

80 ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES 

81 ) 

82 ALLOWED_PRIVATE_KEY_TYPES = ( 

83 RSAPrivateKey, 

84 EllipticCurvePrivateKey, 

85 Ed25519PrivateKey, 

86 Ed448PrivateKey, 

87 ) 

88 ALLOWED_PUBLIC_KEY_TYPES = ( 

89 RSAPublicKey, 

90 EllipticCurvePublicKey, 

91 Ed25519PublicKey, 

92 Ed448PublicKey, 

93 ) 

94 

95 if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")): 

96 from typing import TypeAlias 

97 

98 from cryptography.hazmat.primitives.asymmetric.types import ( 

99 PrivateKeyTypes, 

100 PublicKeyTypes, 

101 ) 

102 

103 # Type aliases for convenience in algorithms method signatures 

104 AllowedRSAKeys: TypeAlias = RSAPrivateKey | RSAPublicKey 

105 AllowedECKeys: TypeAlias = EllipticCurvePrivateKey | EllipticCurvePublicKey 

106 AllowedOKPKeys: TypeAlias = ( 

107 Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey 

108 ) 

109 AllowedKeys: TypeAlias = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys 

110 #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed) 

111 AllowedPrivateKeys: TypeAlias = ( 

112 RSAPrivateKey 

113 | EllipticCurvePrivateKey 

114 | Ed25519PrivateKey 

115 | Ed448PrivateKey 

116 ) 

117 #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed) 

118 AllowedPublicKeys: TypeAlias = ( 

119 RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey 

120 ) 

121 

122 has_crypto = True 

123except ModuleNotFoundError: 

124 has_crypto = False 

125 

126 

# Names of the algorithms that get_default_algorithms() only registers when
# the optional ``cryptography`` package imported successfully.
requires_cryptography = {
    # RSASSA-PKCS1-v1_5 and RSASSA-PSS, each with SHA-256/384/512
    *(f"{family}{bits}" for family in ("RS", "PS") for bits in (256, 384, 512)),
    # ECDSA variants (ES521 and ES512 are both kept, see get_default_algorithms)
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    # Edwards-curve signatures
    "EdDSA",
}

141 

142 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build the name -> Algorithm mapping of everything this library implements.

    ``none`` and the HMAC family are always registered. The RSA, ECDSA,
    RSA-PSS and EdDSA algorithms are added only when ``has_crypto`` is true,
    i.e. when the ``cryptography`` imports succeeded.
    """
    registry: dict[str, Algorithm] = {"none": NoneAlgorithm()}

    hmac_variants = (
        ("HS256", HMACAlgorithm.SHA256),
        ("HS384", HMACAlgorithm.SHA384),
        ("HS512", HMACAlgorithm.SHA512),
    )
    for alg_name, digest in hmac_variants:
        registry[alg_name] = HMACAlgorithm(digest)

    if not has_crypto:
        return registry

    # Insertion order below intentionally mirrors the historical table.
    rsa_variants = (
        ("RS256", RSAAlgorithm.SHA256),
        ("RS384", RSAAlgorithm.SHA384),
        ("RS512", RSAAlgorithm.SHA512),
    )
    for alg_name, digest in rsa_variants:
        registry[alg_name] = RSAAlgorithm(digest)

    ec_variants = (
        ("ES256", ECAlgorithm.SHA256),
        ("ES256K", ECAlgorithm.SHA256),
        ("ES384", ECAlgorithm.SHA384),
        ("ES521", ECAlgorithm.SHA512),
        ("ES512", ECAlgorithm.SHA512),  # Backward compat for #219 fix
    )
    for alg_name, digest in ec_variants:
        registry[alg_name] = ECAlgorithm(digest)

    pss_variants = (
        ("PS256", RSAPSSAlgorithm.SHA256),
        ("PS384", RSAPSSAlgorithm.SHA384),
        ("PS512", RSAPSSAlgorithm.SHA512),
    )
    for alg_name, digest in pss_variants:
        registry[alg_name] = RSAPSSAlgorithm(digest)

    registry["EdDSA"] = OKPAlgorithm()
    return registry

175 

176 

class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Concrete subclasses implement key preparation, signing, verification and
    JWK (de)serialization for one algorithm family.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family.
    # Left as ``None`` by algorithms that do not work with ``cryptography`` key objects.
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        ``self.hash_alg`` may be a ``cryptography`` ``HashAlgorithm`` subclass
        (hashed through ``hashes.Hash``) or a hashlib-style constructor such
        as ``hashlib.sha256`` (called with the data and ``.digest()``-ed).

        :param bytestr: Data to hash
        :return: The raw digest bytes
        :raises NotImplementedError: If there is no hash algorithm
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Materialize the names: interpolating a bare generator expression
            # into the f-string would print "<generator object ...>".
            valid_classes = ", ".join(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """

274 

275 

class NoneAlgorithm(Algorithm):
    """
    Implementation of ``alg = "none"``.

    Signing always yields an empty signature and verification always
    reports failure. No key is ever used.
    """

    def prepare_key(self, key: str | None) -> None:
        """Accept only ``None`` (an empty string counts as ``None``)."""
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsecured tokens carry an empty signature segment.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Nothing can be verified without a key, so this never succeeds.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        """There is no key to serialize for ``none``."""
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        """There is no key to deserialize for ``none``."""
        raise NotImplementedError()

304 

305 

class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert the shared secret to bytes.

        :raises InvalidKeyError: if the secret looks like a PEM/SSH asymmetric
            key or certificate, which must not be used as an HMAC secret
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize a shared secret as an ``oct`` JWK.

        :return: the JWK as a dict when ``as_dict`` is true, else as a JSON string
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Deserialize an ``oct`` JWK (JSON string or dict) into the raw secret.

        :raises InvalidKeyError: if the input is not a JSON object, is not an
            ``oct`` key, or lacks the ``k`` parameter
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        # json.loads also returns lists, numbers and strings; only an object
        # can be a JWK (otherwise .get() below would raise AttributeError).
        if not isinstance(obj, dict):
            raise InvalidKeyError("Key is not a JSON object")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison of the expected and presented MACs.
        return hmac.compare_digest(sig, self.sign(msg, key))

376 

377 

if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_RSA_KEY_TYPES

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Turn ``key`` into an RSA ``cryptography`` key object.

            RSA key objects are returned unchanged. Input starting with
            ``ssh-rsa`` is loaded as an OpenSSH public key. Anything else is
            tried as a PEM private key (``password=None``) and then as a PEM
            public key.

            :raises TypeError: if ``key`` is neither an RSA key object nor str/bytes
            :raises InvalidKeyError: if the key cannot be parsed, or parses to
                a key of a non-RSA family
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                else:
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        key_bytes, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
            except ValueError:
                # Not loadable as above: last attempt is a PEM public key.
                try:
                    public_key = load_pem_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK.

            Private keys include ``d``, ``p``, ``q`` and the CRT parameters.
            Public keys include only ``n`` and ``e``.

            :raises InvalidKeyError: if ``key_obj`` is neither kind of key
            """
            obj: dict[str, Any]

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Deserialize an RSA JWK (JSON string or dict) into a key object.

            A JWK with ``d``, ``e`` and ``n`` becomes a private key. When none
            of ``p``/``q``/``dp``/``dq``/``qi`` are given, the primes are
            recovered from ``n``, ``e`` and ``d``. A JWK with only ``n`` and
            ``e`` becomes a public key.

            :raises InvalidKeyError: on non-JSON / non-object input, a non-RSA
                ``kty``, unsupported multi-prime keys, a partial set of CRT
                parameters, or parameter values ``cryptography`` rejects
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # json.loads also returns lists, numbers and strings; only an
            # object can be a JWK (otherwise .get() would raise AttributeError).
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            try:
                if "d" in obj and "e" in obj and "n" in obj:
                    # Private key
                    if "oth" in obj:
                        raise InvalidKeyError(
                            "Unsupported RSA private key: > 2 primes not supported"
                        )

                    other_props = ["p", "q", "dp", "dq", "qi"]
                    props_found = [prop in obj for prop in other_props]
                    any_props_found = any(props_found)

                    if any_props_found and not all(props_found):
                        raise InvalidKeyError(
                            "RSA key must include all parameters if any are present besides d"
                        ) from None

                    public_numbers = RSAPublicNumbers(
                        from_base64url_uint(obj["e"]),
                        from_base64url_uint(obj["n"]),
                    )

                    if any_props_found:
                        numbers = RSAPrivateNumbers(
                            d=from_base64url_uint(obj["d"]),
                            p=from_base64url_uint(obj["p"]),
                            q=from_base64url_uint(obj["q"]),
                            dmp1=from_base64url_uint(obj["dp"]),
                            dmq1=from_base64url_uint(obj["dq"]),
                            iqmp=from_base64url_uint(obj["qi"]),
                            public_numbers=public_numbers,
                        )
                    else:
                        d = from_base64url_uint(obj["d"])
                        p, q = rsa_recover_prime_factors(
                            public_numbers.n, d, public_numbers.e
                        )

                        numbers = RSAPrivateNumbers(
                            d=d,
                            p=p,
                            q=q,
                            dmp1=rsa_crt_dmp1(d, p),
                            dmq1=rsa_crt_dmq1(d, q),
                            iqmp=rsa_crt_iqmp(p, q),
                            public_numbers=public_numbers,
                        )

                    return numbers.private_key()
                elif "n" in obj and "e" in obj:
                    # Public key
                    return RSAPublicNumbers(
                        from_base64url_uint(obj["e"]),
                        from_base64url_uint(obj["n"]),
                    ).public_key()
            except ValueError as err:
                # Malformed base64, inconsistent numbers, or primes that cannot
                # be recovered: report as a key error, like OKPAlgorithm does.
                raise InvalidKeyError("Invalid key parameter") from err

            raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

557 

if has_crypto:

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_EC_KEY_TYPES

        # JWK "crv" name -> (curve class, required x/y coordinate length in bytes)
        _JWK_CURVES: ClassVar[dict[str, tuple[type[EllipticCurve], int]]] = {
            "P-256": (SECP256R1, 32),
            "P-384": (SECP384R1, 48),
            "P-521": (SECP521R1, 66),
            "secp256k1": (SECP256K1, 32),
        }

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Turn ``key`` into an EC ``cryptography`` key object.

            EC key objects are returned unchanged. Text starting with
            ``ecdsa-sha2-`` is loaded as an OpenSSH public key, other text as
            a PEM public key. If that raises ``ValueError``, it is loaded as a
            PEM private key (``password=None``).

            :raises TypeError: if ``key`` is neither an EC key object nor str/bytes
            :raises InvalidKeyError: if the key parses to a non-EC key family
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                else:
                    public_key = load_pem_public_key(key_bytes)

                # Explicit check the key to prevent confusing errors from cryptography
                self.check_crypto_key_type(public_key)
                return cast(EllipticCurvePublicKey, public_key)
            except ValueError:
                private_key = load_pem_private_key(key_bytes, password=None)
                self.check_crypto_key_type(private_key)
                return cast(EllipticCurvePrivateKey, private_key)

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            """Sign ``msg`` and return the signature in raw (r || s) JWS form."""
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Check a raw (r || s) signature. A private key is reduced to its
            public half. Malformed signatures yield ``False``.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key as a JWK.

            Coordinates (and ``d`` for private keys) are encoded at the full
            curve key size.

            :raises InvalidKeyError: for non-EC objects or unsupported curves
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Deserialize an EC JWK (JSON string or dict) into a key object.

            :raises InvalidKeyError: on non-JSON / non-object input, a non-EC
                ``kty``, missing coordinates, an unsupported ``crv``, wrongly
                sized ``x``/``y``/``d``, or values ``cryptography`` rejects
                (e.g. a point that is not on the curve)
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # json.loads also returns lists, numbers and strings; only an
            # object can be a JWK (otherwise .get() would raise AttributeError).
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key") from None

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key") from None

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            try:
                curve_cls, coord_size = ECAlgorithm._JWK_CURVES[curve]
            except (KeyError, TypeError):
                # TypeError covers unhashable "crv" values such as lists.
                raise InvalidKeyError(f"Invalid curve: {curve}") from None

            if not len(x) == len(y) == coord_size:
                raise InvalidKeyError(
                    f"Coords should be {coord_size} bytes for curve {curve}"
                )
            curve_obj: EllipticCurve = curve_cls()

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                try:
                    return public_numbers.public_key()
                except ValueError as err:
                    raise InvalidKeyError("Invalid key parameter") from err

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                # Previously the arguments were passed to the exception instead
                # of being formatted into the message.
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            try:
                return EllipticCurvePrivateNumbers(
                    int.from_bytes(d, byteorder="big"), public_numbers
                ).private_key()
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err

750 

if has_crypto:

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        RSASSA-PSS (MGF1) signing and verification.

        Key preparation and JWK handling are inherited unchanged from
        RSAAlgorithm. Only the padding scheme differs.
        """

        def _pss_padding(self) -> padding.PSS:
            # Both the MGF1 hash and the salt length follow the configured hash.
            hasher = self.hash_alg()
            salt_len = self.hash_alg().digest_size
            return padding.PSS(mgf=padding.MGF1(hasher), salt_length=salt_len)

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Return the PSS signature of ``msg`` made with ``key``."""
            return key.sign(msg, self._pss_padding(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return whether ``sig`` is a valid PSS signature of ``msg``."""
            try:
                key.verify(sig, msg, self._pss_padding(), self.hash_alg())
            except InvalidSignature:
                return False
            return True

780 

if has_crypto:

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA

        This class requires ``cryptography>=2.6`` to be installed.
        """

        _crypto_key_types = ALLOWED_OKP_KEY_TYPES

        def __init__(self, **kwargs: Any) -> None:
            # Keyword arguments are accepted and ignored; there is no state.
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Turn ``key`` into an Ed25519/Ed448 ``cryptography`` key object.

            Non-text keys are type-checked and returned as-is. Text containing
            ``-----BEGIN PUBLIC`` / ``-----BEGIN PRIVATE`` is loaded as PEM, and
            text starting with ``ssh-`` as an OpenSSH public key.

            :raises InvalidKeyError: if the text matches none of those forms,
                or the key is not an Ed25519/Ed448 key
            """
            if not isinstance(key, (str, bytes)):
                self.check_crypto_key_type(key)
                return cast("AllowedOKPKeys", key)

            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            key_bytes = key.encode("utf-8") if isinstance(key, str) else key

            loaded_key: PublicKeyTypes | PrivateKeyTypes
            if "-----BEGIN PUBLIC" in key_str:
                loaded_key = load_pem_public_key(key_bytes)
            elif "-----BEGIN PRIVATE" in key_str:
                loaded_key = load_pem_private_key(key_bytes, password=None)
            elif key_str[0:4] == "ssh-":
                loaded_key = load_ssh_public_key(key_bytes)
            else:
                raise InvalidKeyError("Not a public or private key")

            # Explicit check the key to prevent confusing errors from cryptography
            self.check_crypto_key_type(loaded_key)
            return cast("AllowedOKPKeys", loaded_key)

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``

            :param str|bytes msg: Message to sign (str is UTF-8 encoded)
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            return key.sign(msg_bytes)

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

            :param str|bytes sig: EdDSA signature to check ``msg`` against
            :param str|bytes msg: Message to verify
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance (a private key is
                reduced to its public half)
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

                public_key = (
                    key.public_key()
                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                    else key
                )
                public_key.verify(sig_bytes, msg_bytes)
                return True  # If no exception was raised, the signature is valid.
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``OKP`` JWK.

            Public keys yield ``x``. Private keys additionally yield ``d``.

            :raises InvalidKeyError: for any other object
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "d": base64url_encode(force_bytes(d)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Deserialize an ``OKP`` JWK (JSON string or dict) into a key object.

            :raises InvalidKeyError: on non-JSON / non-object input, a non-OKP
                ``kty``, an unsupported ``crv``, a missing ``x``, or key bytes
                that ``cryptography`` rejects
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # json.loads also returns lists, numbers and strings; only an
            # object can be a JWK (otherwise .get() would raise AttributeError).
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519" and curve != "Ed448":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    if curve == "Ed25519":
                        return Ed25519PublicKey.from_public_bytes(x)
                    return Ed448PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                if curve == "Ed25519":
                    return Ed25519PrivateKey.from_private_bytes(d)
                return Ed448PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err