Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 18%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

385 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6from abc import ABC, abstractmethod 

7from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload 

8 

9from .exceptions import InvalidKeyError 

10from .types import HashlibHash, JWKDict 

11from .utils import ( 

12 base64url_decode, 

13 base64url_encode, 

14 der_to_raw_signature, 

15 force_bytes, 

16 from_base64url_uint, 

17 is_pem_format, 

18 is_ssh_key, 

19 raw_to_der_signature, 

20 to_base64url_uint, 

21) 

22 

23try: 

24 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

25 from cryptography.hazmat.backends import default_backend 

26 from cryptography.hazmat.primitives import hashes 

27 from cryptography.hazmat.primitives.asymmetric import padding 

28 from cryptography.hazmat.primitives.asymmetric.ec import ( 

29 ECDSA, 

30 SECP256K1, 

31 SECP256R1, 

32 SECP384R1, 

33 SECP521R1, 

34 EllipticCurve, 

35 EllipticCurvePrivateKey, 

36 EllipticCurvePrivateNumbers, 

37 EllipticCurvePublicKey, 

38 EllipticCurvePublicNumbers, 

39 ) 

40 from cryptography.hazmat.primitives.asymmetric.ed448 import ( 

41 Ed448PrivateKey, 

42 Ed448PublicKey, 

43 ) 

44 from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

45 Ed25519PrivateKey, 

46 Ed25519PublicKey, 

47 ) 

48 from cryptography.hazmat.primitives.asymmetric.rsa import ( 

49 RSAPrivateKey, 

50 RSAPrivateNumbers, 

51 RSAPublicKey, 

52 RSAPublicNumbers, 

53 rsa_crt_dmp1, 

54 rsa_crt_dmq1, 

55 rsa_crt_iqmp, 

56 rsa_recover_prime_factors, 

57 ) 

58 from cryptography.hazmat.primitives.serialization import ( 

59 Encoding, 

60 NoEncryption, 

61 PrivateFormat, 

62 PublicFormat, 

63 load_pem_private_key, 

64 load_pem_public_key, 

65 load_ssh_public_key, 

66 ) 

67 

68 has_crypto = True 

69except ModuleNotFoundError: 

70 has_crypto = False 

71 

72 

if TYPE_CHECKING:
    # Key-type aliases used in the algorithm method signatures. This branch
    # is only seen by static type checkers; it never executes at runtime.
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey
        | Ed25519PublicKey
        | Ed448PrivateKey
        | Ed448PublicKey
    )

    # Aggregate aliases spanning every key family above.
    AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
    AllowedPrivateKeys = (
        RSAPrivateKey
        | EllipticCurvePrivateKey
        | Ed25519PrivateKey
        | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey
        | EllipticCurvePublicKey
        | Ed25519PublicKey
        | Ed448PublicKey
    )


# Algorithm names whose implementations need the optional ``cryptography``
# package; these are exactly the entries ``get_default_algorithms`` adds
# only when ``has_crypto`` is true.
requires_cryptography = {
    "RS256", "RS384", "RS512",  # RSASSA-PKCS1-v1_5
    "ES256", "ES256K", "ES384", "ES521", "ES512",  # ECDSA
    "PS256", "PS384", "PS512",  # RSASSA-PSS
    "EdDSA",  # Edwards-curve signatures
}

103 

104 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Return a new dict mapping each supported ``alg`` name to an instance.

    ``"none"`` and the HMAC family are always present. The RSA, ECDSA,
    RSA-PSS and EdDSA entries are appended only when the optional
    ``cryptography`` dependency was importable (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    # Without cryptography only the always-available entries can be offered.
    if not has_crypto:
        return algorithms

    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512)
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512)  # Backward compat for #219 fix
    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    algorithms["EdDSA"] = OKPAlgorithm()

    return algorithms

137 

138 

class Algorithm(ABC):
    """
    Abstract base for every signing algorithm the library supports.

    Concrete subclasses normalise key material (:meth:`prepare_key`),
    create and check signatures (:meth:`sign` / :meth:`verify`) and
    convert keys to and from JSON Web Key form (:meth:`to_jwk` /
    :meth:`from_jwk`).
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Return the digest of ``bytestr`` under this algorithm's ``hash_alg``.

        When ``cryptography`` is available and ``hash_alg`` is a
        ``hashes.HashAlgorithm`` subclass, the digest is computed through
        ``cryptography``'s hashing API; otherwise ``hash_alg`` is called
        hashlib-style as ``hash_alg(bytestr).digest()``.

        :raises NotImplementedError: if the instance has no ``hash_alg``.
        """
        # Looked up dynamically because the base class never declares
        # ``hash_alg``; this form is also understood by mypy.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        # ``has_crypto`` is tested first so ``hashes`` is only touched when
        # the optional import succeeded.
        is_cryptography_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_cryptography_hash:
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate and convert ``key`` into the form that :meth:`sign` and
        :meth:`verify` expect, and return it.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Return the signature of ``msg`` produced with ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Return whether ``sig`` is a valid signature of ``msg`` for ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Rebuild a key object from a JWK given as a JSON string or a dict.
        """

210 

211 

class NoneAlgorithm(Algorithm):
    """
    Algorithm used for ``alg="none"``: tokens carry no signature.

    Signing yields an empty signature, verification always reports failure,
    and JWK conversion is not supported.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is accepted as an alias for "no key"; any other
        # non-None value is rejected.
        if key == "" or key is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Nothing to sign with, so the signature is empty.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # An unsigned token is never reported as verified.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        raise NotImplementedError()

240 

241 

class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    # hashlib constructors selectable as the HMAC digest.
    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes so it can be used as the HMAC secret.

        :raises InvalidKeyError: if the key looks like PEM or SSH key
            material, i.e. an asymmetric key or certificate is being misused
            as a shared secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Wrap a shared secret in an ``oct`` JWK whose ``k`` member holds
        ``base64url_encode`` of the secret.

        Returns the dict when ``as_dict`` is true, otherwise its JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Return the shared secret stored in an ``oct`` JWK.

        :param jwk: the JWK as a JSON string or an already-parsed dict.
        :raises InvalidKeyError: if ``jwk`` is not a JSON object, its
            ``kty`` is not ``oct``, or it has no ``k`` member.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
            # Valid JSON that is not an object (e.g. a list or string) cannot
            # be a JWK; report it like malformed JSON instead of failing
            # later with AttributeError on ``.get``.
            if not isinstance(obj, dict):
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        # Without this check a missing member surfaces as a bare KeyError.
        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` keyed with ``key`` using ``hash_alg``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """Recompute the HMAC of ``msg`` and compare it with ``sig``."""
        # compare_digest is used instead of ``==`` to reduce exposure to
        # timing attacks on the MAC comparison.
        return hmac.compare_digest(sig, self.sign(msg, key))

312 

313 

if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return an RSA key object for ``key``.

            RSA key objects pass through unchanged. ``str``/``bytes`` input is
            loaded as an OpenSSH public key when it starts with ``ssh-rsa``,
            otherwise as an unencrypted PEM private key; if either raises
            ``ValueError``, loading as a PEM public key is tried instead.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the PEM public-key fallback fails, or
                the loaded key is not an RSA key.
            """
            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            crypto_key: object
            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    crypto_key = load_ssh_public_key(key_bytes)
                else:
                    crypto_key = load_pem_private_key(key_bytes, password=None)
            except ValueError:
                try:
                    crypto_key = load_pem_public_key(key_bytes)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

            # The loaders accept any key type (EC, Ed25519, ...). Reject
            # non-RSA keys here, as ECAlgorithm/OKPAlgorithm do, rather than
            # letting sign()/verify() fail later with a confusing TypeError.
            if not isinstance(crypto_key, (RSAPrivateKey, RSAPublicKey)):
                raise InvalidKeyError(
                    "Expecting a RSAPrivateKey/RSAPublicKey. Wrong key provided for RSA algorithms"
                )

            return crypto_key

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK.

            Objects exposing ``private_numbers`` are exported with the private
            exponent and CRT parameters and ``key_ops: ["sign"]``; objects
            exposing ``verify`` are exported as public keys with
            ``key_ops: ["verify"]``. Returns a dict when ``as_dict`` is true,
            otherwise a JSON string.

            :raises InvalidKeyError: if ``key_obj`` matches neither case.
            """
            obj: dict[str, Any]

            if hasattr(key_obj, "private_numbers"):
                # Private key
                private_numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(private_numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(private_numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(private_numbers.d).decode(),
                    "p": to_base64url_uint(private_numbers.p).decode(),
                    "q": to_base64url_uint(private_numbers.q).decode(),
                    "dp": to_base64url_uint(private_numbers.dmp1).decode(),
                    "dq": to_base64url_uint(private_numbers.dmq1).decode(),
                    "qi": to_base64url_uint(private_numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                public_numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(public_numbers.n).decode(),
                    "e": to_base64url_uint(public_numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key from a JWK given as a JSON string or dict.

            A JWK with ``d``, ``e`` and ``n`` yields a private key. If none of
            ``p``, ``q``, ``dp``, ``dq``, ``qi`` is present, the primes and CRT
            values are recovered from ``n``, ``e`` and ``d``. A JWK with only
            ``n`` and ``e`` yields a public key.

            :raises InvalidKeyError: for invalid or non-object JSON, a ``kty``
                other than ``RSA``, multi-prime keys (``oth``), a partial set
                of CRT parameters, or missing ``n``/``e``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object cannot be a JWK; report it
                # like malformed JSON instead of failing on ``.get`` below.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key")

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    )

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d was supplied: derive p, q and the CRT values.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PKCS#1 v1.5 padding and ``hash_alg``."""
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` is a valid PKCS#1 v1.5 signature of ``msg``."""
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

if has_crypto:

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Return an elliptic-curve key object for ``key``.

            EC key objects pass through unchanged. ``str``/``bytes`` input is
            first loaded as a public key: OpenSSH when it starts with
            ``ecdsa-sha2-``, PEM otherwise. If that raises ``ValueError``, it
            is loaded as an unencrypted PEM private key, and any error from
            that call propagates.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the loaded key is not an EC key.
            """
            if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    crypto_key = load_ssh_public_key(key_bytes)
                else:
                    crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
            except ValueError:
                crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]

            # Explicit check the key to prevent confusing errors from cryptography
            if not isinstance(
                crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
            ):
                raise InvalidKeyError(
                    "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
                )

            return crypto_key

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            """
            Sign ``msg`` with ECDSA over ``hash_alg``, returning the signature
            converted from DER by ``der_to_raw_signature`` for ``key.curve``.
            """
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Return True if ``sig`` is a valid ECDSA signature of ``msg``.

            ``sig`` is converted with ``raw_to_der_signature`` (a ``ValueError``
            there means "not valid"), and a private ``key`` is reduced to its
            public half before verifying.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key as a JWK with ``crv``, ``x``, ``y`` and, for
            private keys, ``d``.

            Coordinates are padded to the curve's ``key_size``. Returns a dict
            when ``as_dict`` is true, otherwise a JSON string.

            :raises InvalidKeyError: for non-EC keys or curves other than
                P-256, P-384, P-521 and secp256k1.
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Build an EC key from a JWK given as a JSON string or dict.

            ``x`` and ``y`` must be present and have the byte length required
            by ``crv``. If ``d`` is present, a private key is returned and
            ``d`` must be as long as ``x``; otherwise a public key is returned.

            :raises InvalidKeyError: for invalid or non-object JSON, a ``kty``
                other than ``EC``, missing coordinates, an unsupported curve,
                or wrongly sized ``x``/``y``/``d``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object cannot be a JWK; report it
                # like malformed JSON instead of failing on ``.get`` below.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key")

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key")

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            curve_obj: EllipticCurve

            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve P-256"
                    )
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = SECP384R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 48 bytes for curve P-384"
                    )
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = SECP521R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 66 bytes for curve P-521"
                    )
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                # Previously the template and its values were passed as
                # separate exception args, so the message was never formatted.
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            return EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()

680 

if has_crypto:

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        RSASSA-PSS signatures with MGF1 mask generation.

        Key preparation and JWK conversion are inherited from
        :class:`RSAAlgorithm`; only the padding scheme differs.
        """

        def _pss_padding(self) -> padding.PSS:
            """
            Build PSS padding whose MGF1 uses ``hash_alg`` and whose salt is
            as long as that hash's digest.
            """
            return padding.PSS(
                mgf=padding.MGF1(self.hash_alg()),
                salt_length=self.hash_alg().digest_size,
            )

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with ``key`` using PSS padding and ``hash_alg``."""
            return key.sign(msg, self._pss_padding(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` is a valid PSS signature of ``msg``."""
            try:
                key.verify(sig, msg, self._pss_padding(), self.hash_alg())
            except InvalidSignature:
                return False
            return True

710 

if has_crypto:

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA

        This class requires ``cryptography>=2.6`` to be installed.
        """

        def __init__(self, **kwargs: Any) -> None:
            # Keyword arguments are accepted and ignored; no per-instance
            # state is kept because sign/verify delegate to the key object.
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Return an Ed25519/Ed448 key object for ``key``.

            ``str``/``bytes`` input is loaded according to its content: text
            containing ``-----BEGIN PUBLIC`` as a PEM public key, text
            containing ``-----BEGIN PRIVATE`` as an unencrypted PEM private
            key, or text starting with ``ssh-`` as an OpenSSH public key.
            Anything else, including unrecognised text, is checked as-is.

            :raises InvalidKeyError: if the result is not an Ed25519 or Ed448
                key.
            """
            if isinstance(key, (bytes, str)):
                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
                key_bytes = key.encode("utf-8") if isinstance(key, str) else key

                if "-----BEGIN PUBLIC" in key_str:
                    key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
                elif "-----BEGIN PRIVATE" in key_str:
                    key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
                elif key_str[0:4] == "ssh-":
                    key = load_ssh_public_key(key_bytes)  # type: ignore[assignment]

            # Explicit check the key to prevent confusing errors from cryptography
            if not isinstance(
                key,
                (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
            ):
                # The message previously named the ECDSA key classes, which
                # is misleading for EdDSA.
                raise InvalidKeyError(
                    "Expecting a Ed25519PrivateKey/Ed25519PublicKey/"
                    "Ed448PrivateKey/Ed448PublicKey. "
                    "Wrong key provided for EdDSA algorithms"
                )

            return key

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``

            :param str|bytes msg: Message to sign; ``str`` is UTF-8 encoded
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            return key.sign(msg_bytes)

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

            :param str|bytes sig: EdDSA signature to check ``msg`` against;
                ``str`` is UTF-8 encoded
            :param str|bytes msg: Message to verify; ``str`` is UTF-8 encoded
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance; a private key is
                reduced to its public half before verifying
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

                public_key = (
                    key.public_key()
                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                    else key
                )
                public_key.verify(sig_bytes, msg_bytes)
                return True  # If no exception was raised, the signature is valid.
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``OKP`` JWK.

            Public keys produce ``x`` and ``crv``; private keys additionally
            produce ``d`` (raw private bytes, unencrypted). Returns a dict when
            ``as_dict`` is true, otherwise a JSON string.

            :raises InvalidKeyError: if ``key`` is not an Ed25519/Ed448 key.
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "d": base64url_encode(force_bytes(d)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Build an Ed25519/Ed448 key from an ``OKP`` JWK (JSON string or dict).

            ``x`` is required. When ``d`` is present, a private key is built
            from ``d`` alone; ``x`` is decoded but not cross-checked against it.

            :raises InvalidKeyError: for invalid or non-object JSON, a ``kty``
                other than ``OKP``, an unsupported ``crv``, a missing ``x``,
                or key bytes rejected by ``cryptography``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
                # Valid JSON that is not an object cannot be a JWK; report it
                # like malformed JSON instead of failing on ``.get`` below.
                if not isinstance(obj, dict):
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519" and curve != "Ed448":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    if curve == "Ed25519":
                        return Ed25519PublicKey.from_public_bytes(x)
                    return Ed448PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                if curve == "Ed25519":
                    return Ed25519PrivateKey.from_private_bytes(d)
                return Ed448PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err