Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/algorithms.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

385 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6from abc import ABC, abstractmethod 

7from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload 

8 

9from .exceptions import InvalidKeyError 

10from .types import HashlibHash, JWKDict 

11from .utils import ( 

12 base64url_decode, 

13 base64url_encode, 

14 der_to_raw_signature, 

15 force_bytes, 

16 from_base64url_uint, 

17 is_pem_format, 

18 is_ssh_key, 

19 raw_to_der_signature, 

20 to_base64url_uint, 

21) 

22 

23try: 

24 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

25 from cryptography.hazmat.backends import default_backend 

26 from cryptography.hazmat.primitives import hashes 

27 from cryptography.hazmat.primitives.asymmetric import padding 

28 from cryptography.hazmat.primitives.asymmetric.ec import ( 

29 ECDSA, 

30 SECP256K1, 

31 SECP256R1, 

32 SECP384R1, 

33 SECP521R1, 

34 EllipticCurve, 

35 EllipticCurvePrivateKey, 

36 EllipticCurvePrivateNumbers, 

37 EllipticCurvePublicKey, 

38 EllipticCurvePublicNumbers, 

39 ) 

40 from cryptography.hazmat.primitives.asymmetric.ed448 import ( 

41 Ed448PrivateKey, 

42 Ed448PublicKey, 

43 ) 

44 from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

45 Ed25519PrivateKey, 

46 Ed25519PublicKey, 

47 ) 

48 from cryptography.hazmat.primitives.asymmetric.rsa import ( 

49 RSAPrivateKey, 

50 RSAPrivateNumbers, 

51 RSAPublicKey, 

52 RSAPublicNumbers, 

53 rsa_crt_dmp1, 

54 rsa_crt_dmq1, 

55 rsa_crt_iqmp, 

56 rsa_recover_prime_factors, 

57 ) 

58 from cryptography.hazmat.primitives.serialization import ( 

59 Encoding, 

60 NoEncryption, 

61 PrivateFormat, 

62 PublicFormat, 

63 load_pem_private_key, 

64 load_pem_public_key, 

65 load_ssh_public_key, 

66 ) 

67 

68 has_crypto = True 

69except ModuleNotFoundError: 

70 has_crypto = False 

71 

72 

if TYPE_CHECKING:
    # Key-type unions used only in annotations; this branch never runs at
    # runtime because TYPE_CHECKING is False outside static type checkers.
    AllowedPrivateKeys = (
        RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
    )
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
    )
    AllowedKeys = AllowedPrivateKeys | AllowedPublicKeys


# Algorithm names that are only available when ``cryptography`` is installed:
# the RSASSA-PKCS1-v1_5 and RSASSA-PSS families at three digest sizes, the
# ECDSA variants (including the legacy "ES521" alias), and EdDSA.
requires_cryptography = {
    f"{family}{bits}" for family in ("RS", "PS") for bits in (256, 384, 512)
} | {"ES256", "ES256K", "ES384", "ES521", "ES512", "EdDSA"}

103 

104 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Return a freshly built mapping of algorithm name -> algorithm instance
    for every algorithm implemented by the library.

    ``"none"`` and the HMAC family are always included. The RSA, ECDSA,
    RSA-PSS and EdDSA entries are added only when ``cryptography`` could be
    imported (``has_crypto``).
    """
    algorithms = dict(
        none=NoneAlgorithm(),
        HS256=HMACAlgorithm(HMACAlgorithm.SHA256),
        HS384=HMACAlgorithm(HMACAlgorithm.SHA384),
        HS512=HMACAlgorithm(HMACAlgorithm.SHA512),
    )

    if not has_crypto:
        return algorithms

    # Keyword order is preserved, so entries keep their original order.
    algorithms.update(
        RS256=RSAAlgorithm(RSAAlgorithm.SHA256),
        RS384=RSAAlgorithm(RSAAlgorithm.SHA384),
        RS512=RSAAlgorithm(RSAAlgorithm.SHA512),
        ES256=ECAlgorithm(ECAlgorithm.SHA256),
        ES256K=ECAlgorithm(ECAlgorithm.SHA256),
        ES384=ECAlgorithm(ECAlgorithm.SHA384),
        ES521=ECAlgorithm(ECAlgorithm.SHA512),
        ES512=ECAlgorithm(ECAlgorithm.SHA512),  # Backward compat for #219 fix
        PS256=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
        PS384=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
        PS512=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
        EdDSA=OKPAlgorithm(),
    )
    return algorithms

137 

138 

class Algorithm(ABC):
    """
    Abstract base for every JWS algorithm in this module.

    Subclasses implement key preparation, signing, verification and JWK
    (de)serialisation. ``compute_hash_digest`` is shared and relies on the
    optional ``hash_alg`` attribute that a subclass may set.
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Hash ``bytestr`` with this algorithm's ``hash_alg``.

        ``hash_alg`` is handled in one of two ways:

        * a ``cryptography`` ``hashes.HashAlgorithm`` subclass (only
          considered when ``cryptography`` is importable) is instantiated
          and fed through ``hashes.Hash``;
        * anything else is called like a hashlib constructor:
          ``hash_alg(bytestr).digest()``.

        :raises NotImplementedError: if ``hash_alg`` is missing or ``None``.
        """
        # getattr() instead of self.hash_alg: only some subclasses define it,
        # and this spelling is one mypy accepts.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_crypto_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_crypto_hash:
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate ``key`` and convert it into the form ``sign()`` and
        ``verify()`` expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Produce a signature over ``msg`` using the prepared ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Report whether ``sig`` is a valid signature of ``msg`` under ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Rebuild a key object from a JWK given as a JSON string or a dict.
        """

210 

211 

class NoneAlgorithm(Algorithm):
    """
    The ``"none"`` algorithm, for use when no signing or verification is
    wanted.

    ``sign`` always returns an empty byte string and ``verify`` always
    returns ``False``.
    """

    def prepare_key(self, key: str | None) -> None:
        """
        Accept only an absent key: ``None`` or the empty string.

        :raises InvalidKeyError: for any other value.
        """
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # No signature is produced.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Nothing is ever accepted as a valid signature.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialise for "none".
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialise for "none".
        raise NotImplementedError()

240 

241 

class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC with the
    hashlib hash constructor given at construction time.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # Passed straight to hmac.new() as the digest constructor.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert ``key`` to bytes for use as the HMAC secret.

        :raises InvalidKeyError: if the bytes look like PEM material or an
            SSH key (per ``is_pem_format`` / ``is_ssh_key``). Such keys must
            not be used as shared secrets.
        """
        key_bytes = force_bytes(key)
        if not (is_pem_format(key_bytes) or is_ssh_key(key_bytes)):
            return key_bytes
        raise InvalidKeyError(
            "The specified key is an asymmetric key or x509 certificate and"
            " should not be used as an HMAC secret."
        )

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise the secret as an ``"oct"`` JWK whose ``k`` member is the
        base64url-encoded secret. Returns a dict if ``as_dict`` is true,
        otherwise its JSON text.
        """
        encoded_secret = base64url_encode(force_bytes(key_obj)).decode()
        jwk = {"k": encoded_secret, "kty": "oct"}
        return jwk if as_dict else json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode the raw secret bytes from an ``"oct"`` JWK (JSON string or
        dict).

        :raises InvalidKeyError: if ``jwk`` is not valid JSON or not a dict
            or string, or if its ``kty`` is not ``"oct"``.
        """
        try:
            if isinstance(jwk, dict):
                obj: JWKDict = jwk
            elif isinstance(jwk, str):
                obj = json.loads(jwk)
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` keyed with ``key``."""
        mac = hmac.new(key, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """Recompute the HMAC and compare it to ``sig`` in constant time."""
        expected = self.sign(msg, key)
        return hmac.compare_digest(sig, expected)

312 

313 

314if has_crypto: 

315 

class RSAAlgorithm(Algorithm):
    """
    Performs signing and verification operations using RSASSA-PKCS-v1_5
    and the ``cryptography`` hash class given at construction time.
    """

    SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
    SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
    SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

    def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
        # A hash *class*; a fresh instance is created for each operation.
        self.hash_alg = hash_alg

    def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
        """
        Normalise ``key`` into an RSA key object.

        RSA key objects are returned unchanged. ``str``/``bytes`` input is
        parsed as an OpenSSH ``ssh-rsa`` public key or a PEM private key
        (no password). If that raises ``ValueError``, it is parsed as a PEM
        public key instead.

        NOTE(review): the parsed key is only ``cast`` to an RSA type, not
        checked, so a non-RSA PEM key would be returned as-is.

        :raises TypeError: if ``key`` is neither a key object nor str/bytes.
        :raises InvalidKeyError: if the PEM public-key fallback fails.
        """
        if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
            return key
        if not isinstance(key, (bytes, str)):
            raise TypeError("Expecting a PEM-formatted key.")

        key_bytes = force_bytes(key)
        try:
            if not key_bytes.startswith(b"ssh-rsa"):
                return cast(
                    RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
                )
            return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
        except ValueError:
            # Retry as a PEM public key. The retry is nested in the handler
            # so the first failure stays attached as exception context.
            try:
                return cast(RSAPublicKey, load_pem_public_key(key_bytes))
            except (ValueError, UnsupportedAlgorithm):
                raise InvalidKeyError("Could not parse the provided public key.")

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedRSAKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise an RSA key as a JWK with ``kty`` = ``"RSA"``.

        Objects with a ``private_numbers`` attribute are exported with
        n, e, d, p, q, dp, dq and qi and ``key_ops`` = ``["sign"]``.
        Otherwise, objects with a ``verify`` attribute are exported with
        n and e and ``key_ops`` = ``["verify"]``.

        :returns: the JWK dict when ``as_dict`` is true, else its JSON text.
        :raises InvalidKeyError: if ``key_obj`` has neither attribute.
        """

        def encode(value: int) -> str:
            return to_base64url_uint(value).decode()

        if hasattr(key_obj, "private_numbers"):
            private = key_obj.private_numbers()
            public = private.public_numbers
            obj = {
                "kty": "RSA",
                "key_ops": ["sign"],
                "n": encode(public.n),
                "e": encode(public.e),
                "d": encode(private.d),
                "p": encode(private.p),
                "q": encode(private.q),
                "dp": encode(private.dmp1),
                "dq": encode(private.dmq1),
                "qi": encode(private.iqmp),
            }
        elif hasattr(key_obj, "verify"):
            public = key_obj.public_numbers()
            obj = {
                "kty": "RSA",
                "key_ops": ["verify"],
                "n": encode(public.n),
                "e": encode(public.e),
            }
        else:
            raise InvalidKeyError("Not a public or private key")

        return obj if as_dict else json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
        """
        Rebuild an RSA key from a JWK (JSON string or dict).

        With n and e only, a public key is returned. With d as well, a
        private key is returned. If p, q, dp, dq and qi are all present they
        are used directly. If none are present, p and q are recovered from
        n, e and d, and the CRT values are derived from them.

        :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``, missing
            n/e, an ``oth`` member (more than two primes), or a partial set
            of CRT parameters.
        """
        try:
            if isinstance(jwk, dict):
                obj = jwk
            elif isinstance(jwk, str):
                obj = json.loads(jwk)
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "RSA":
            raise InvalidKeyError("Not an RSA key")

        if "n" not in obj or "e" not in obj:
            raise InvalidKeyError("Not a public or private key")

        if "d" not in obj:
            return RSAPublicNumbers(
                from_base64url_uint(obj["e"]),
                from_base64url_uint(obj["n"]),
            ).public_key()

        if "oth" in obj:
            raise InvalidKeyError(
                "Unsupported RSA private key: > 2 primes not supported"
            )

        crt_fields = ("p", "q", "dp", "dq", "qi")
        present = [field in obj for field in crt_fields]
        has_crt = any(present)
        if has_crt and not all(present):
            raise InvalidKeyError(
                "RSA key must include all parameters if any are present besides d"
            )

        public_numbers = RSAPublicNumbers(
            from_base64url_uint(obj["e"]),
            from_base64url_uint(obj["n"]),
        )
        d = from_base64url_uint(obj["d"])

        if has_crt:
            p = from_base64url_uint(obj["p"])
            q = from_base64url_uint(obj["q"])
            dmp1 = from_base64url_uint(obj["dp"])
            dmq1 = from_base64url_uint(obj["dq"])
            iqmp = from_base64url_uint(obj["qi"])
        else:
            # Only d was supplied: recover the primes, then derive CRT values.
            p, q = rsa_recover_prime_factors(public_numbers.n, d, public_numbers.e)
            dmp1 = rsa_crt_dmp1(d, p)
            dmq1 = rsa_crt_dmq1(d, q)
            iqmp = rsa_crt_iqmp(p, q)

        return RSAPrivateNumbers(
            d=d,
            p=p,
            q=q,
            dmp1=dmp1,
            dmq1=dmq1,
            iqmp=iqmp,
            public_numbers=public_numbers,
        ).private_key()

    def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
        """Sign ``msg`` with PKCS#1 v1.5 padding."""
        return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

    def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
        """Return ``True`` iff ``sig`` verifies for ``msg`` under ``key``."""
        try:
            key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
        except InvalidSignature:
            return False
        return True

483 

class ECAlgorithm(Algorithm):
    """
    Performs signing and verification operations using
    ECDSA and the specified hash function.

    ``sign`` returns the DER signature re-encoded with
    ``der_to_raw_signature``. ``verify`` expects that same encoding and
    converts it back with ``raw_to_der_signature``.
    """

    SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
    SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
    SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

    def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
        # A hash *class*; a fresh instance is created for each operation.
        self.hash_alg = hash_alg

    def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
        """
        Normalise ``key`` into an elliptic-curve key object.

        EC key objects are returned unchanged. ``str``/``bytes`` input is
        parsed as an OpenSSH ``ecdsa-sha2-*`` public key or a PEM public key.
        If that raises ``ValueError``, it is parsed as a PEM private key
        (no password).

        :raises TypeError: if ``key`` is neither a key object nor str/bytes.
        :raises InvalidKeyError: if the parsed key is not an EC key.

        Errors raised by the final private-key parse are not caught here.
        """
        if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
            return key

        if not isinstance(key, (bytes, str)):
            raise TypeError("Expecting a PEM-formatted key.")

        key_bytes = force_bytes(key)

        # Attempt to load key. We don't know if it's
        # a Signing Key or a Verifying Key, so we try
        # the Verifying Key first.
        try:
            if key_bytes.startswith(b"ecdsa-sha2-"):
                crypto_key = load_ssh_public_key(key_bytes)
            else:
                crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
        except ValueError:
            crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]

        # Explicit check the key to prevent confusing errors from cryptography
        if not isinstance(
            crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
        ):
            raise InvalidKeyError(
                "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
            )

        return crypto_key

    def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
        """
        Sign ``msg`` with ECDSA and return the signature re-encoded from
        DER with ``der_to_raw_signature`` for ``key``'s curve.
        """
        der_sig = key.sign(msg, ECDSA(self.hash_alg()))

        return der_to_raw_signature(der_sig, key.curve)

    def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
        """
        Return ``True`` iff ``sig`` is a valid signature of ``msg``.

        A private key is accepted and its public half is used. A signature
        that ``raw_to_der_signature`` rejects with ``ValueError`` yields
        ``False`` instead of raising.
        """
        try:
            der_sig = raw_to_der_signature(sig, key.curve)
        except ValueError:
            return False

        try:
            public_key = (
                key.public_key()
                if isinstance(key, EllipticCurvePrivateKey)
                else key
            )
            public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
            return True
        except InvalidSignature:
            return False

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedECKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedECKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise an EC key as a JWK with ``kty`` = ``"EC"``.

        Supported curves are P-256, P-384, P-521 and secp256k1. Private keys
        additionally get a ``d`` member.

        :returns: the JWK dict when ``as_dict`` is true, else its JSON text.
        :raises InvalidKeyError: for a non-EC key or an unsupported curve.
        """
        if isinstance(key_obj, EllipticCurvePrivateKey):
            public_numbers = key_obj.public_key().public_numbers()
        elif isinstance(key_obj, EllipticCurvePublicKey):
            public_numbers = key_obj.public_numbers()
        else:
            raise InvalidKeyError("Not a public or private key")

        if isinstance(key_obj.curve, SECP256R1):
            crv = "P-256"
        elif isinstance(key_obj.curve, SECP384R1):
            crv = "P-384"
        elif isinstance(key_obj.curve, SECP521R1):
            crv = "P-521"
        elif isinstance(key_obj.curve, SECP256K1):
            crv = "secp256k1"
        else:
            raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

        obj: dict[str, Any] = {
            "kty": "EC",
            "crv": crv,
            "x": to_base64url_uint(public_numbers.x).decode(),
            "y": to_base64url_uint(public_numbers.y).decode(),
        }

        if isinstance(key_obj, EllipticCurvePrivateKey):
            obj["d"] = to_base64url_uint(
                key_obj.private_numbers().private_value
            ).decode()

        if as_dict:
            return obj
        else:
            return json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
        """
        Rebuild an EC key from a JWK (JSON string or dict).

        ``x`` and ``y`` must be present and have the byte length required by
        ``crv`` (32 for P-256 and secp256k1, 48 for P-384, 66 for P-521).
        Without ``d`` a public key is returned. With ``d`` (which must be as
        long as ``x``) a private key is returned.

        :raises InvalidKeyError: on invalid JSON, a non-EC ``kty``, missing
            coordinates, an unsupported curve, or wrong-length values.
        """
        try:
            if isinstance(jwk, str):
                obj = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "EC":
            raise InvalidKeyError("Not an Elliptic curve key")

        if "x" not in obj or "y" not in obj:
            raise InvalidKeyError("Not an Elliptic curve key")

        x = base64url_decode(obj.get("x"))
        y = base64url_decode(obj.get("y"))

        curve = obj.get("crv")
        curve_obj: EllipticCurve

        if curve == "P-256":
            if len(x) == len(y) == 32:
                curve_obj = SECP256R1()
            else:
                raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
        elif curve == "P-384":
            if len(x) == len(y) == 48:
                curve_obj = SECP384R1()
            else:
                raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
        elif curve == "P-521":
            if len(x) == len(y) == 66:
                curve_obj = SECP521R1()
            else:
                raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
        elif curve == "secp256k1":
            if len(x) == len(y) == 32:
                curve_obj = SECP256K1()
            else:
                raise InvalidKeyError(
                    "Coords should be 32 bytes for curve secp256k1"
                )
        else:
            raise InvalidKeyError(f"Invalid curve: {curve}")

        public_numbers = EllipticCurvePublicNumbers(
            x=int.from_bytes(x, byteorder="big"),
            y=int.from_bytes(y, byteorder="big"),
            curve=curve_obj,
        )

        if "d" not in obj:
            return public_numbers.public_key()

        d = base64url_decode(obj.get("d"))
        if len(d) != len(x):
            # Format the message here. Passing the values as extra exception
            # arguments (as before) left the "{}" placeholders unfilled.
            raise InvalidKeyError(f"D should be {len(x)} bytes for curve {curve}")

        return EllipticCurvePrivateNumbers(
            int.from_bytes(d, byteorder="big"), public_numbers
        ).private_key()

665 

class RSAPSSAlgorithm(RSAAlgorithm):
    """
    RSASSA-PSS variant of ``RSAAlgorithm``: MGF1 with the configured hash,
    and a salt length equal to that hash's digest size.
    """

    def _pss_padding(self) -> padding.PSS:
        """Build the PSS padding shared by ``sign`` and ``verify``."""
        return padding.PSS(
            mgf=padding.MGF1(self.hash_alg()),
            salt_length=self.hash_alg().digest_size,
        )

    def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
        """Sign ``msg`` with RSASSA-PSS."""
        return key.sign(msg, self._pss_padding(), self.hash_alg())

    def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
        """Return ``True`` iff ``sig`` is a valid RSASSA-PSS signature."""
        try:
            key.verify(sig, msg, self._pss_padding(), self.hash_alg())
        except InvalidSignature:
            return False
        return True

695 

class OKPAlgorithm(Algorithm):
    """
    Performs signing and verification operations using EdDSA with Ed25519
    or Ed448 keys.

    This class requires ``cryptography>=2.6`` to be installed.
    """

    def __init__(self, **kwargs: Any) -> None:
        # No configuration: keyword arguments are accepted and ignored.
        pass

    def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
        """
        Normalise ``key`` into an Ed25519/Ed448 key object.

        ``str``/``bytes`` input is inspected as UTF-8 text and loaded as a
        PEM public key (``-----BEGIN PUBLIC``), a PEM private key
        (``-----BEGIN PRIVATE``, no password) or an OpenSSH public key
        (``ssh-`` prefix). Any other value is type-checked as given.

        :raises InvalidKeyError: if the result is not an Ed25519/Ed448 key.
        """
        prepared = key
        if isinstance(key, (bytes, str)):
            if isinstance(key, bytes):
                key_bytes = key
                key_text = key.decode("utf-8")
            else:
                key_text = key
                key_bytes = key.encode("utf-8")

            if "-----BEGIN PUBLIC" in key_text:
                prepared = load_pem_public_key(key_bytes)  # type: ignore[assignment]
            elif "-----BEGIN PRIVATE" in key_text:
                prepared = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
            elif key_text.startswith("ssh-"):
                prepared = load_ssh_public_key(key_bytes)  # type: ignore[assignment]

        okp_key_types = (
            Ed25519PrivateKey,
            Ed25519PublicKey,
            Ed448PrivateKey,
            Ed448PublicKey,
        )
        # Explicit check the key to prevent confusing errors from cryptography
        if not isinstance(prepared, okp_key_types):
            raise InvalidKeyError(
                "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
            )
        return prepared

    def sign(
        self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
    ) -> bytes:
        """
        Sign a message ``msg`` using the EdDSA private key ``key``.

        :param str|bytes msg: Message to sign; ``str`` is UTF-8 encoded first
        :param Ed25519PrivateKey|Ed448PrivateKey key: An
            :class:`.Ed25519PrivateKey` or :class:`.Ed448PrivateKey` instance
        :return bytes signature: The signature, as bytes
        """
        if isinstance(msg, str):
            msg = msg.encode("utf-8")
        return key.sign(msg)

    def verify(
        self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
    ) -> bool:
        """
        Verify a given ``msg`` against a signature ``sig`` using the EdDSA key
        ``key``.

        :param str|bytes sig: EdDSA signature to check ``msg`` against;
            ``str`` is UTF-8 encoded first
        :param str|bytes msg: Message that was signed; ``str`` is UTF-8
            encoded first
        :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
            A private or public EdDSA key instance (a private key's public
            half is used)
        :return bool verified: True if signature is valid, False if not.
        """
        msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
        sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

        if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
            verifier = key.public_key()
        else:
            verifier = key

        try:
            verifier.verify(sig_bytes, msg_bytes)
        except InvalidSignature:
            return False
        # No exception means the signature is valid.
        return True

    @overload
    @staticmethod
    def to_jwk(
        key: AllowedOKPKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key: AllowedOKPKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise an Ed25519/Ed448 key as an ``"OKP"`` JWK.

        ``x`` holds the raw public bytes (base64url). Private keys also get
        ``d``, the raw private bytes. ``crv`` is ``"Ed25519"`` or ``"Ed448"``.

        :returns: the JWK dict when ``as_dict`` is true, else its JSON text.
        :raises InvalidKeyError: if ``key`` is not an Ed25519/Ed448 key.
        """
        if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
            public_key = key
            private_raw = None
            crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
        elif isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
            private_raw = key.private_bytes(
                encoding=Encoding.Raw,
                format=PrivateFormat.Raw,
                encryption_algorithm=NoEncryption(),
            )
            public_key = key.public_key()
            crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
        else:
            raise InvalidKeyError("Not a public or private key")

        public_raw = public_key.public_bytes(
            encoding=Encoding.Raw,
            format=PublicFormat.Raw,
        )

        # Member order (x, [d,] kty, crv) is kept stable for the JSON output.
        obj = {"x": base64url_encode(force_bytes(public_raw)).decode()}
        if private_raw is not None:
            obj["d"] = base64url_encode(force_bytes(private_raw)).decode()
        obj["kty"] = "OKP"
        obj["crv"] = crv

        return obj if as_dict else json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
        """
        Rebuild an Ed25519/Ed448 key from an ``"OKP"`` JWK (JSON string or
        dict).

        Without ``d`` a public key is built from ``x``. With ``d`` a private
        key is built from ``d``.

        :raises InvalidKeyError: on invalid JSON, a non-OKP ``kty``, an
            unsupported ``crv``, a missing ``x``, or key bytes rejected with
            ``ValueError`` during decoding/loading.
        """
        try:
            if isinstance(jwk, dict):
                obj = jwk
            elif isinstance(jwk, str):
                obj = json.loads(jwk)
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "OKP":
            raise InvalidKeyError("Not an Octet Key Pair")

        curve = obj.get("crv")
        if curve not in ("Ed25519", "Ed448"):
            raise InvalidKeyError(f"Invalid curve: {curve}")

        if "x" not in obj:
            raise InvalidKeyError('OKP should have "x" parameter')
        x = base64url_decode(obj.get("x"))

        is_ed25519 = curve == "Ed25519"
        try:
            if "d" in obj:
                d = base64url_decode(obj.get("d"))
                private_cls = Ed25519PrivateKey if is_ed25519 else Ed448PrivateKey
                return private_cls.from_private_bytes(d)
            public_cls = Ed25519PublicKey if is_ed25519 else Ed448PublicKey
            return public_cls.from_public_bytes(x)
        except ValueError as err:
            raise InvalidKeyError("Invalid key parameter") from err