Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/algorithms.py: 35%

396 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6import sys 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload 

9 

10from .exceptions import InvalidKeyError 

11from .types import HashlibHash, JWKDict 

12from .utils import ( 

13 base64url_decode, 

14 base64url_encode, 

15 der_to_raw_signature, 

16 force_bytes, 

17 from_base64url_uint, 

18 is_pem_format, 

19 is_ssh_key, 

20 raw_to_der_signature, 

21 to_base64url_uint, 

22) 

23 

24if sys.version_info >= (3, 8): 

25 from typing import Literal 

26else: 

27 from typing_extensions import Literal 

28 

29 

30try: 

31 from cryptography.exceptions import InvalidSignature 

32 from cryptography.hazmat.backends import default_backend 

33 from cryptography.hazmat.primitives import hashes 

34 from cryptography.hazmat.primitives.asymmetric import padding 

35 from cryptography.hazmat.primitives.asymmetric.ec import ( 

36 ECDSA, 

37 SECP256K1, 

38 SECP256R1, 

39 SECP384R1, 

40 SECP521R1, 

41 EllipticCurve, 

42 EllipticCurvePrivateKey, 

43 EllipticCurvePrivateNumbers, 

44 EllipticCurvePublicKey, 

45 EllipticCurvePublicNumbers, 

46 ) 

47 from cryptography.hazmat.primitives.asymmetric.ed448 import ( 

48 Ed448PrivateKey, 

49 Ed448PublicKey, 

50 ) 

51 from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

52 Ed25519PrivateKey, 

53 Ed25519PublicKey, 

54 ) 

55 from cryptography.hazmat.primitives.asymmetric.rsa import ( 

56 RSAPrivateKey, 

57 RSAPrivateNumbers, 

58 RSAPublicKey, 

59 RSAPublicNumbers, 

60 rsa_crt_dmp1, 

61 rsa_crt_dmq1, 

62 rsa_crt_iqmp, 

63 rsa_recover_prime_factors, 

64 ) 

65 from cryptography.hazmat.primitives.serialization import ( 

66 Encoding, 

67 NoEncryption, 

68 PrivateFormat, 

69 PublicFormat, 

70 load_pem_private_key, 

71 load_pem_public_key, 

72 load_ssh_public_key, 

73 ) 

74 

75 has_crypto = True 

76except ModuleNotFoundError: 

77 has_crypto = False 

78 

79 

if TYPE_CHECKING:
    # Convenience aliases for the key types accepted by the algorithm classes
    # in their method signatures. This branch only exists for static type
    # checkers and never executes, so it is safe even when the guarded
    # ``cryptography`` import failed.
    AllowedRSAKeys = Union[RSAPrivateKey, RSAPublicKey]
    AllowedECKeys = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]
    AllowedOKPKeys = Union[
        Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
    ]
    AllowedKeys = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
    AllowedPrivateKeys = Union[
        RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
    ]
    AllowedPublicKeys = Union[
        RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
    ]


# Algorithm names that get_default_algorithms() registers only when the
# optional ``cryptography`` package is importable (``has_crypto``).
requires_cryptography = {
    "RS256", "RS384", "RS512",  # RSASSA-PKCS1-v1_5
    "ES256", "ES256K", "ES384", "ES521", "ES512",  # ECDSA
    "PS256", "PS384", "PS512",  # RSASSA-PSS
    "EdDSA",  # Edwards-curve (Ed25519 / Ed448)
}

110 

111 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Return a fresh mapping from algorithm name to ``Algorithm`` instance for
    every algorithm implemented by the library.

    ``none`` and the HMAC family (``HS256``/``HS384``/``HS512``) are always
    present. The RSA (``RS*``), ECDSA (``ES*``), RSA-PSS (``PS*``) and
    ``EdDSA`` entries are added only when the optional ``cryptography``
    package could be imported (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}
    algorithms["HS256"] = HMACAlgorithm(HMACAlgorithm.SHA256)
    algorithms["HS384"] = HMACAlgorithm(HMACAlgorithm.SHA384)
    algorithms["HS512"] = HMACAlgorithm(HMACAlgorithm.SHA512)

    if not has_crypto:
        return algorithms

    algorithms.update(
        RS256=RSAAlgorithm(RSAAlgorithm.SHA256),
        RS384=RSAAlgorithm(RSAAlgorithm.SHA384),
        RS512=RSAAlgorithm(RSAAlgorithm.SHA512),
        ES256=ECAlgorithm(ECAlgorithm.SHA256),
        ES256K=ECAlgorithm(ECAlgorithm.SHA256),
        ES384=ECAlgorithm(ECAlgorithm.SHA384),
        ES521=ECAlgorithm(ECAlgorithm.SHA512),
        # Backward compat for #219 fix
        ES512=ECAlgorithm(ECAlgorithm.SHA512),
        PS256=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
        PS384=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
        PS512=RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
        EdDSA=OKPAlgorithm(),
    )
    return algorithms

144 

145 

class Algorithm(ABC):
    """
    Abstract interface implemented by every JWS signing algorithm.

    Concrete subclasses turn user-supplied key material into a usable key
    (``prepare_key``), produce and check signatures (``sign``/``verify``)
    and convert keys to and from JWK form (``to_jwk``/``from_jwk``).
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Hash ``bytestr`` with this algorithm's ``hash_alg`` and return the
        digest bytes.

        ``hash_alg`` may be a ``cryptography`` ``HashAlgorithm`` subclass
        (hashed through ``cryptography``) or a hashlib-style constructor that
        is called with the data and exposes ``digest()``.

        :raises NotImplementedError: if the instance has no ``hash_alg``
            attribute, or it is ``None``.
        """
        # getattr keeps mypy happy: hash_alg is not declared on the base class.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_crypto_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_crypto_hash:
            # hashlib-style: constructor(data).digest()
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate ``key`` and convert it into the form that ``sign()`` and
        ``verify()`` expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Return the signature of ``msg`` computed with ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Return whether ``sig`` is a valid signature of ``msg`` under ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
        ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
        """
        Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a ``str``.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Build a key object from a JWK given either as a string or as a dict.
        """

219 

220 

class NoneAlgorithm(Algorithm):
    """
    Algorithm for ``alg = "none"``: tokens that carry no signature.

    The only accepted key is "no key" (``None`` or the empty string).
    ``sign`` produces an empty signature and ``verify`` never succeeds.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is treated exactly like an absent key.
        normalized = None if key == "" else key
        if normalized is not None:
            raise InvalidKeyError('When alg = "none", key value must be None.')
        return normalized

    def sign(self, msg: bytes, key: None) -> bytes:
        # Unsigned tokens have an empty signature segment.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Nothing can ever be verified without a key.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialize for "none".
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialize for "none".
        raise NotImplementedError()

249 

250 

class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        # hashlib constructor (e.g. hashlib.sha256) used as the HMAC digest.
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes for use as an HMAC secret.

        :raises InvalidKeyError: if the key looks like PEM material or an SSH
            public key.
        """
        key_bytes = force_bytes(key)

        # Asymmetric key material must never double as an HMAC secret;
        # otherwise a (public) key could be used to forge HS* signatures.
        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
        ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
        ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
        """
        Serialize an HMAC secret as an ``oct`` JWK (dict if ``as_dict``,
        otherwise a JSON string).
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode the secret from an ``oct`` JWK given as a JSON string or dict.

        :raises InvalidKeyError: if the input is not a JSON object, its
            ``kty`` is not ``oct``, or the ``k`` parameter is missing.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
            # A JWK must be a JSON object; reject arrays/scalars here rather
            # than failing later with an AttributeError on .get().
            if not isinstance(obj, dict):
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison to avoid leaking the MAC through timing.
        return hmac.compare_digest(sig, self.sign(msg, key))

319 

320 

321if has_crypto: 

322 

323 class RSAAlgorithm(Algorithm): 

324 """ 

325 Performs signing and verification operations using 

326 RSASSA-PKCS-v1_5 and the specified hash function. 

327 """ 

328 

329 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 

330 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 

331 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 

332 

333 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: 

334 self.hash_alg = hash_alg 

335 

336 def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: 

337 if isinstance(key, (RSAPrivateKey, RSAPublicKey)): 

338 return key 

339 

340 if not isinstance(key, (bytes, str)): 

341 raise TypeError("Expecting a PEM-formatted key.") 

342 

343 key_bytes = force_bytes(key) 

344 

345 try: 

346 if key_bytes.startswith(b"ssh-rsa"): 

347 return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) 

348 else: 

349 return cast( 

350 RSAPrivateKey, load_pem_private_key(key_bytes, password=None) 

351 ) 

352 except ValueError: 

353 return cast(RSAPublicKey, load_pem_public_key(key_bytes)) 

354 

355 @overload 

356 @staticmethod 

357 def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: 

358 ... # pragma: no cover 

359 

360 @overload 

361 @staticmethod 

362 def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: 

363 ... # pragma: no cover 

364 

365 @staticmethod 

366 def to_jwk( 

367 key_obj: AllowedRSAKeys, as_dict: bool = False 

368 ) -> Union[JWKDict, str]: 

369 obj: dict[str, Any] | None = None 

370 

371 if hasattr(key_obj, "private_numbers"): 

372 # Private key 

373 numbers = key_obj.private_numbers() 

374 

375 obj = { 

376 "kty": "RSA", 

377 "key_ops": ["sign"], 

378 "n": to_base64url_uint(numbers.public_numbers.n).decode(), 

379 "e": to_base64url_uint(numbers.public_numbers.e).decode(), 

380 "d": to_base64url_uint(numbers.d).decode(), 

381 "p": to_base64url_uint(numbers.p).decode(), 

382 "q": to_base64url_uint(numbers.q).decode(), 

383 "dp": to_base64url_uint(numbers.dmp1).decode(), 

384 "dq": to_base64url_uint(numbers.dmq1).decode(), 

385 "qi": to_base64url_uint(numbers.iqmp).decode(), 

386 } 

387 

388 elif hasattr(key_obj, "verify"): 

389 # Public key 

390 numbers = key_obj.public_numbers() 

391 

392 obj = { 

393 "kty": "RSA", 

394 "key_ops": ["verify"], 

395 "n": to_base64url_uint(numbers.n).decode(), 

396 "e": to_base64url_uint(numbers.e).decode(), 

397 } 

398 else: 

399 raise InvalidKeyError("Not a public or private key") 

400 

401 if as_dict: 

402 return obj 

403 else: 

404 return json.dumps(obj) 

405 

406 @staticmethod 

407 def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: 

408 try: 

409 if isinstance(jwk, str): 

410 obj = json.loads(jwk) 

411 elif isinstance(jwk, dict): 

412 obj = jwk 

413 else: 

414 raise ValueError 

415 except ValueError: 

416 raise InvalidKeyError("Key is not valid JSON") 

417 

418 if obj.get("kty") != "RSA": 

419 raise InvalidKeyError("Not an RSA key") 

420 

421 if "d" in obj and "e" in obj and "n" in obj: 

422 # Private key 

423 if "oth" in obj: 

424 raise InvalidKeyError( 

425 "Unsupported RSA private key: > 2 primes not supported" 

426 ) 

427 

428 other_props = ["p", "q", "dp", "dq", "qi"] 

429 props_found = [prop in obj for prop in other_props] 

430 any_props_found = any(props_found) 

431 

432 if any_props_found and not all(props_found): 

433 raise InvalidKeyError( 

434 "RSA key must include all parameters if any are present besides d" 

435 ) 

436 

437 public_numbers = RSAPublicNumbers( 

438 from_base64url_uint(obj["e"]), 

439 from_base64url_uint(obj["n"]), 

440 ) 

441 

442 if any_props_found: 

443 numbers = RSAPrivateNumbers( 

444 d=from_base64url_uint(obj["d"]), 

445 p=from_base64url_uint(obj["p"]), 

446 q=from_base64url_uint(obj["q"]), 

447 dmp1=from_base64url_uint(obj["dp"]), 

448 dmq1=from_base64url_uint(obj["dq"]), 

449 iqmp=from_base64url_uint(obj["qi"]), 

450 public_numbers=public_numbers, 

451 ) 

452 else: 

453 d = from_base64url_uint(obj["d"]) 

454 p, q = rsa_recover_prime_factors( 

455 public_numbers.n, d, public_numbers.e 

456 ) 

457 

458 numbers = RSAPrivateNumbers( 

459 d=d, 

460 p=p, 

461 q=q, 

462 dmp1=rsa_crt_dmp1(d, p), 

463 dmq1=rsa_crt_dmq1(d, q), 

464 iqmp=rsa_crt_iqmp(p, q), 

465 public_numbers=public_numbers, 

466 ) 

467 

468 return numbers.private_key() 

469 elif "n" in obj and "e" in obj: 

470 # Public key 

471 return RSAPublicNumbers( 

472 from_base64url_uint(obj["e"]), 

473 from_base64url_uint(obj["n"]), 

474 ).public_key() 

475 else: 

476 raise InvalidKeyError("Not a public or private key") 

477 

478 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: 

479 return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) 

480 

481 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: 

482 try: 

483 key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) 

484 return True 

485 except InvalidSignature: 

486 return False 

487 

488 class ECAlgorithm(Algorithm): 

489 """ 

490 Performs signing and verification operations using 

491 ECDSA and the specified hash function 

492 """ 

493 

494 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 

495 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 

496 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 

497 

498 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: 

499 self.hash_alg = hash_alg 

500 

501 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: 

502 if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): 

503 return key 

504 

505 if not isinstance(key, (bytes, str)): 

506 raise TypeError("Expecting a PEM-formatted key.") 

507 

508 key_bytes = force_bytes(key) 

509 

510 # Attempt to load key. We don't know if it's 

511 # a Signing Key or a Verifying Key, so we try 

512 # the Verifying Key first. 

513 try: 

514 if key_bytes.startswith(b"ecdsa-sha2-"): 

515 crypto_key = load_ssh_public_key(key_bytes) 

516 else: 

517 crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] 

518 except ValueError: 

519 crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] 

520 

521 # Explicit check the key to prevent confusing errors from cryptography 

522 if not isinstance( 

523 crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) 

524 ): 

525 raise InvalidKeyError( 

526 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" 

527 ) 

528 

529 return crypto_key 

530 

531 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: 

532 der_sig = key.sign(msg, ECDSA(self.hash_alg())) 

533 

534 return der_to_raw_signature(der_sig, key.curve) 

535 

536 def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: 

537 try: 

538 der_sig = raw_to_der_signature(sig, key.curve) 

539 except ValueError: 

540 return False 

541 

542 try: 

543 public_key = ( 

544 key.public_key() 

545 if isinstance(key, EllipticCurvePrivateKey) 

546 else key 

547 ) 

548 public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) 

549 return True 

550 except InvalidSignature: 

551 return False 

552 

553 @overload 

554 @staticmethod 

555 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: 

556 ... # pragma: no cover 

557 

558 @overload 

559 @staticmethod 

560 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: 

561 ... # pragma: no cover 

562 

563 @staticmethod 

564 def to_jwk( 

565 key_obj: AllowedECKeys, as_dict: bool = False 

566 ) -> Union[JWKDict, str]: 

567 if isinstance(key_obj, EllipticCurvePrivateKey): 

568 public_numbers = key_obj.public_key().public_numbers() 

569 elif isinstance(key_obj, EllipticCurvePublicKey): 

570 public_numbers = key_obj.public_numbers() 

571 else: 

572 raise InvalidKeyError("Not a public or private key") 

573 

574 if isinstance(key_obj.curve, SECP256R1): 

575 crv = "P-256" 

576 elif isinstance(key_obj.curve, SECP384R1): 

577 crv = "P-384" 

578 elif isinstance(key_obj.curve, SECP521R1): 

579 crv = "P-521" 

580 elif isinstance(key_obj.curve, SECP256K1): 

581 crv = "secp256k1" 

582 else: 

583 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") 

584 

585 obj: dict[str, Any] = { 

586 "kty": "EC", 

587 "crv": crv, 

588 "x": to_base64url_uint(public_numbers.x).decode(), 

589 "y": to_base64url_uint(public_numbers.y).decode(), 

590 } 

591 

592 if isinstance(key_obj, EllipticCurvePrivateKey): 

593 obj["d"] = to_base64url_uint( 

594 key_obj.private_numbers().private_value 

595 ).decode() 

596 

597 if as_dict: 

598 return obj 

599 else: 

600 return json.dumps(obj) 

601 

602 @staticmethod 

603 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: 

604 try: 

605 if isinstance(jwk, str): 

606 obj = json.loads(jwk) 

607 elif isinstance(jwk, dict): 

608 obj = jwk 

609 else: 

610 raise ValueError 

611 except ValueError: 

612 raise InvalidKeyError("Key is not valid JSON") 

613 

614 if obj.get("kty") != "EC": 

615 raise InvalidKeyError("Not an Elliptic curve key") 

616 

617 if "x" not in obj or "y" not in obj: 

618 raise InvalidKeyError("Not an Elliptic curve key") 

619 

620 x = base64url_decode(obj.get("x")) 

621 y = base64url_decode(obj.get("y")) 

622 

623 curve = obj.get("crv") 

624 curve_obj: EllipticCurve 

625 

626 if curve == "P-256": 

627 if len(x) == len(y) == 32: 

628 curve_obj = SECP256R1() 

629 else: 

630 raise InvalidKeyError("Coords should be 32 bytes for curve P-256") 

631 elif curve == "P-384": 

632 if len(x) == len(y) == 48: 

633 curve_obj = SECP384R1() 

634 else: 

635 raise InvalidKeyError("Coords should be 48 bytes for curve P-384") 

636 elif curve == "P-521": 

637 if len(x) == len(y) == 66: 

638 curve_obj = SECP521R1() 

639 else: 

640 raise InvalidKeyError("Coords should be 66 bytes for curve P-521") 

641 elif curve == "secp256k1": 

642 if len(x) == len(y) == 32: 

643 curve_obj = SECP256K1() 

644 else: 

645 raise InvalidKeyError( 

646 "Coords should be 32 bytes for curve secp256k1" 

647 ) 

648 else: 

649 raise InvalidKeyError(f"Invalid curve: {curve}") 

650 

651 public_numbers = EllipticCurvePublicNumbers( 

652 x=int.from_bytes(x, byteorder="big"), 

653 y=int.from_bytes(y, byteorder="big"), 

654 curve=curve_obj, 

655 ) 

656 

657 if "d" not in obj: 

658 return public_numbers.public_key() 

659 

660 d = base64url_decode(obj.get("d")) 

661 if len(d) != len(x): 

662 raise InvalidKeyError( 

663 "D should be {} bytes for curve {}", len(x), curve 

664 ) 

665 

666 return EllipticCurvePrivateNumbers( 

667 int.from_bytes(d, byteorder="big"), public_numbers 

668 ).private_key() 

669 

670 class RSAPSSAlgorithm(RSAAlgorithm): 

671 """ 

672 Performs a signature using RSASSA-PSS with MGF1 

673 """ 

674 

675 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: 

676 return key.sign( 

677 msg, 

678 padding.PSS( 

679 mgf=padding.MGF1(self.hash_alg()), 

680 salt_length=self.hash_alg().digest_size, 

681 ), 

682 self.hash_alg(), 

683 ) 

684 

685 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: 

686 try: 

687 key.verify( 

688 sig, 

689 msg, 

690 padding.PSS( 

691 mgf=padding.MGF1(self.hash_alg()), 

692 salt_length=self.hash_alg().digest_size, 

693 ), 

694 self.hash_alg(), 

695 ) 

696 return True 

697 except InvalidSignature: 

698 return False 

699 

700 class OKPAlgorithm(Algorithm): 

701 """ 

702 Performs signing and verification operations using EdDSA 

703 

704 This class requires ``cryptography>=2.6`` to be installed. 

705 """ 

706 

707 def __init__(self, **kwargs: Any) -> None: 

708 pass 

709 

710 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: 

711 if isinstance(key, (bytes, str)): 

712 key_str = key.decode("utf-8") if isinstance(key, bytes) else key 

713 key_bytes = key.encode("utf-8") if isinstance(key, str) else key 

714 

715 if "-----BEGIN PUBLIC" in key_str: 

716 key = load_pem_public_key(key_bytes) # type: ignore[assignment] 

717 elif "-----BEGIN PRIVATE" in key_str: 

718 key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] 

719 elif key_str[0:4] == "ssh-": 

720 key = load_ssh_public_key(key_bytes) # type: ignore[assignment] 

721 

722 # Explicit check the key to prevent confusing errors from cryptography 

723 if not isinstance( 

724 key, 

725 (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey), 

726 ): 

727 raise InvalidKeyError( 

728 "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms" 

729 ) 

730 

731 return key 

732 

733 def sign( 

734 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey 

735 ) -> bytes: 

736 """ 

737 Sign a message ``msg`` using the EdDSA private key ``key`` 

738 :param str|bytes msg: Message to sign 

739 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` 

740 or :class:`.Ed448PrivateKey` isinstance 

741 :return bytes signature: The signature, as bytes 

742 """ 

743 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

744 return key.sign(msg_bytes) 

745 

746 def verify( 

747 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes 

748 ) -> bool: 

749 """ 

750 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` 

751 

752 :param str|bytes sig: EdDSA signature to check ``msg`` against 

753 :param str|bytes msg: Message to sign 

754 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key: 

755 A private or public EdDSA key instance 

756 :return bool verified: True if signature is valid, False if not. 

757 """ 

758 try: 

759 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

760 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig 

761 

762 public_key = ( 

763 key.public_key() 

764 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) 

765 else key 

766 ) 

767 public_key.verify(sig_bytes, msg_bytes) 

768 return True # If no exception was raised, the signature is valid. 

769 except InvalidSignature: 

770 return False 

771 

772 @overload 

773 @staticmethod 

774 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: 

775 ... # pragma: no cover 

776 

777 @overload 

778 @staticmethod 

779 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: 

780 ... # pragma: no cover 

781 

782 @staticmethod 

783 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: 

784 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): 

785 x = key.public_bytes( 

786 encoding=Encoding.Raw, 

787 format=PublicFormat.Raw, 

788 ) 

789 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" 

790 

791 obj = { 

792 "x": base64url_encode(force_bytes(x)).decode(), 

793 "kty": "OKP", 

794 "crv": crv, 

795 } 

796 

797 if as_dict: 

798 return obj 

799 else: 

800 return json.dumps(obj) 

801 

802 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): 

803 d = key.private_bytes( 

804 encoding=Encoding.Raw, 

805 format=PrivateFormat.Raw, 

806 encryption_algorithm=NoEncryption(), 

807 ) 

808 

809 x = key.public_key().public_bytes( 

810 encoding=Encoding.Raw, 

811 format=PublicFormat.Raw, 

812 ) 

813 

814 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" 

815 obj = { 

816 "x": base64url_encode(force_bytes(x)).decode(), 

817 "d": base64url_encode(force_bytes(d)).decode(), 

818 "kty": "OKP", 

819 "crv": crv, 

820 } 

821 

822 if as_dict: 

823 return obj 

824 else: 

825 return json.dumps(obj) 

826 

827 raise InvalidKeyError("Not a public or private key") 

828 

829 @staticmethod 

830 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: 

831 try: 

832 if isinstance(jwk, str): 

833 obj = json.loads(jwk) 

834 elif isinstance(jwk, dict): 

835 obj = jwk 

836 else: 

837 raise ValueError 

838 except ValueError: 

839 raise InvalidKeyError("Key is not valid JSON") 

840 

841 if obj.get("kty") != "OKP": 

842 raise InvalidKeyError("Not an Octet Key Pair") 

843 

844 curve = obj.get("crv") 

845 if curve != "Ed25519" and curve != "Ed448": 

846 raise InvalidKeyError(f"Invalid curve: {curve}") 

847 

848 if "x" not in obj: 

849 raise InvalidKeyError('OKP should have "x" parameter') 

850 x = base64url_decode(obj.get("x")) 

851 

852 try: 

853 if "d" not in obj: 

854 if curve == "Ed25519": 

855 return Ed25519PublicKey.from_public_bytes(x) 

856 return Ed448PublicKey.from_public_bytes(x) 

857 d = base64url_decode(obj.get("d")) 

858 if curve == "Ed25519": 

859 return Ed25519PrivateKey.from_private_bytes(d) 

860 return Ed448PrivateKey.from_private_bytes(d) 

861 except ValueError as err: 

862 raise InvalidKeyError("Invalid key parameter") from err