Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

456 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6import os 

7import sys 

8from abc import ABC, abstractmethod 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 ClassVar, 

13 Literal, 

14 NoReturn, 

15 Union, 

16 cast, 

17 get_args, 

18 overload, 

19) 

20 

21from .exceptions import InvalidKeyError 

22from .types import HashlibHash, JWKDict 

23from .utils import ( 

24 base64url_decode, 

25 base64url_encode, 

26 der_to_raw_signature, 

27 force_bytes, 

28 from_base64url_uint, 

29 is_pem_format, 

30 is_ssh_key, 

31 raw_to_der_signature, 

32 to_base64url_uint, 

33) 

34 

35try: 

36 from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

37 from cryptography.hazmat.backends import default_backend 

38 from cryptography.hazmat.primitives import hashes 

39 from cryptography.hazmat.primitives.asymmetric import padding 

40 from cryptography.hazmat.primitives.asymmetric.ec import ( 

41 ECDSA, 

42 SECP256K1, 

43 SECP256R1, 

44 SECP384R1, 

45 SECP521R1, 

46 EllipticCurve, 

47 EllipticCurvePrivateKey, 

48 EllipticCurvePrivateNumbers, 

49 EllipticCurvePublicKey, 

50 EllipticCurvePublicNumbers, 

51 ) 

52 from cryptography.hazmat.primitives.asymmetric.ed448 import ( 

53 Ed448PrivateKey, 

54 Ed448PublicKey, 

55 ) 

56 from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

57 Ed25519PrivateKey, 

58 Ed25519PublicKey, 

59 ) 

60 from cryptography.hazmat.primitives.asymmetric.rsa import ( 

61 RSAPrivateKey, 

62 RSAPrivateNumbers, 

63 RSAPublicKey, 

64 RSAPublicNumbers, 

65 rsa_crt_dmp1, 

66 rsa_crt_dmq1, 

67 rsa_crt_iqmp, 

68 rsa_recover_prime_factors, 

69 ) 

70 from cryptography.hazmat.primitives.serialization import ( 

71 Encoding, 

72 NoEncryption, 

73 PrivateFormat, 

74 PublicFormat, 

75 load_pem_private_key, 

76 load_pem_public_key, 

77 load_ssh_public_key, 

78 ) 

79 

80 if sys.version_info >= (3, 10): 

81 from typing import TypeAlias 

82 else: 

83 # Python 3.9 and lower 

84 from typing_extensions import TypeAlias 

85 

86 # Type aliases for convenience in algorithms method signatures 

87 AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey] 

88 AllowedECKeys: TypeAlias = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey] 

89 AllowedOKPKeys: TypeAlias = Union[ 

90 Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey 

91 ] 

92 AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys] 

93 #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed) 

94 AllowedPrivateKeys: TypeAlias = Union[ 

95 RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey 

96 ] 

97 #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed) 

98 AllowedPublicKeys: TypeAlias = Union[ 

99 RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey 

100 ] 

101 

102 if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")): 

103 from cryptography.hazmat.primitives.asymmetric.types import ( 

104 PrivateKeyTypes, 

105 PublicKeyTypes, 

106 ) 

107 

108 has_crypto = True 

109except ModuleNotFoundError: 

110 if sys.version_info >= (3, 11): 

111 from typing import Never 

112 else: 

113 from typing_extensions import Never 

114 

115 AllowedRSAKeys = Never # type: ignore[misc] 

116 AllowedECKeys = Never # type: ignore[misc] 

117 AllowedOKPKeys = Never # type: ignore[misc] 

118 AllowedKeys = Never # type: ignore[misc] 

119 AllowedPrivateKeys = Never # type: ignore[misc] 

120 AllowedPublicKeys = Never # type: ignore[misc] 

121 has_crypto = False 

122 

123 

# Names of the algorithms that get_default_algorithms() only registers when
# the optional ``cryptography`` package could be imported (has_crypto is True).
requires_cryptography = {
    # RSASSA-PKCS-v1_5
    "RS256",
    "RS384",
    "RS512",
    # ECDSA
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    # RSASSA-PSS
    "PS256",
    "PS384",
    "PS512",
    # EdDSA
    "EdDSA",
}

138 

139 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build a new mapping from algorithm name to algorithm instance.

    The ``none`` and HMAC entries are always present. The RSA, ECDSA,
    RSA-PSS and EdDSA entries are added only when ``has_crypto`` is true.
    """
    registry: dict[str, Algorithm] = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    # Without ``cryptography`` only the algorithms above are available.
    if not has_crypto:
        return registry

    registry["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    registry["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    registry["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    registry["ES256"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
    registry["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
    registry["ES384"] = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
    registry["ES521"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    # Backward compat for #219 fix
    registry["ES512"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    registry["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    registry["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    registry["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    registry["EdDSA"] = OKPAlgorithm()

    return registry

172 

173 

class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        If there is no hash algorithm, raises a NotImplementedError.

        :param bytes bytestr: the data to hash
        :return: the digest, as bytes
        :raises NotImplementedError: if the instance has no ``hash_alg``
            attribute, or it is ``None``
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            # ``cryptography`` hash classes are driven through ``hashes.Hash``
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            # anything else is called hashlib-style, directly with the data
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Materialise the names: interpolating a bare generator expression
            # would render as "<generator object ...>" in the message.
            valid_classes = tuple(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        This base implementation always returns None.
        """
        return None

280 

281 

class NoneAlgorithm(Algorithm):
    """
    Algorithm registered under the name ``none``: it accepts no key,
    produces an empty signature and never reports a signature as valid.
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is treated like "no key"; anything else is rejected.
        if key is None or key == "":
            return None

        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # The signature is always empty.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Verification unconditionally fails.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialise.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialise.
        raise NotImplementedError()

310 

311 

class HMACAlgorithm(Algorithm):
    """
    HMAC-based signing and verification, parameterised by the
    hash function used for the MAC.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert the secret to bytes, rejecting anything that looks like a
        PEM document or an SSH key.
        """
        secret = force_bytes(key)

        looks_asymmetric = is_pem_format(secret) or is_ssh_key(secret)
        if looks_asymmetric:
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return secret

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialise the secret as an ``oct`` JWK, either as a dict
        (``as_dict=True``) or as a JSON string.
        """
        encoded_secret = base64url_encode(force_bytes(key_obj)).decode()
        jwk = {
            "k": encoded_secret,
            "kty": "oct",
        }

        return jwk if as_dict else json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Extract the secret bytes from an ``oct`` JWK given as a JSON string
        or as a dict.
        """
        try:
            if isinstance(jwk, str):
                parsed: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                parsed = jwk
            else:
                # Neither a JSON string nor a dict: report it like bad JSON.
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if parsed.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(parsed["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """
        Return a warning when the key is shorter than the digest size of the
        configured hash, otherwise None.
        """
        min_length = self.hash_alg().digest_size
        if len(key) >= min_length:
            return None

        return (
            f"The HMAC key is {len(key)} bytes long, which is below "
            f"the minimum recommended length of {min_length} bytes for "
            f"{self.hash_alg().name.upper()}. "
            f"See RFC 7518 Section 3.2."
        )

    def sign(self, msg: bytes, key: bytes) -> bytes:
        mac = hmac.new(key, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Recompute the MAC and compare with hmac.compare_digest.
        expected = self.sign(msg, key)
        return hmac.compare_digest(sig, expected)

389 

390 

391if has_crypto: 

392 

393 class RSAAlgorithm(Algorithm): 

394 """ 

395 Performs signing and verification operations using 

396 RSASSA-PKCS-v1_5 and the specified hash function. 

397 """ 

398 

399 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 

400 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 

401 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 

402 

403 _crypto_key_types = cast( 

404 tuple[type[AllowedKeys], ...], 

405 get_args(Union[RSAPrivateKey, RSAPublicKey]), 

406 ) 

407 _MIN_KEY_SIZE: ClassVar[int] = 2048 

408 

409 def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: 

410 self.hash_alg = hash_alg 

411 

412 def check_key_length(self, key: AllowedRSAKeys) -> str | None: 

413 if key.key_size < self._MIN_KEY_SIZE: 

414 return ( 

415 f"The RSA key is {key.key_size} bits long, which is below " 

416 f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. " 

417 f"See NIST SP 800-131A." 

418 ) 

419 return None 

420 

421 def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: 

422 if isinstance(key, self._crypto_key_types): 

423 return cast(AllowedRSAKeys, key) 

424 

425 if not isinstance(key, (bytes, str)): 

426 raise TypeError("Expecting a PEM-formatted key.") 

427 

428 key_bytes = force_bytes(key) 

429 

430 try: 

431 if key_bytes.startswith(b"ssh-rsa"): 

432 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes) 

433 self.check_crypto_key_type(public_key) 

434 return cast(RSAPublicKey, public_key) 

435 else: 

436 private_key: PrivateKeyTypes = load_pem_private_key( 

437 key_bytes, password=None 

438 ) 

439 self.check_crypto_key_type(private_key) 

440 return cast(RSAPrivateKey, private_key) 

441 except ValueError: 

442 try: 

443 public_key = load_pem_public_key(key_bytes) 

444 self.check_crypto_key_type(public_key) 

445 return cast(RSAPublicKey, public_key) 

446 except (ValueError, UnsupportedAlgorithm): 

447 raise InvalidKeyError( 

448 "Could not parse the provided public key." 

449 ) from None 

450 

451 @overload 

452 @staticmethod 

453 def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ... 

454 

455 @overload 

456 @staticmethod 

457 def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ... 

458 

459 @staticmethod 

460 def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: 

461 obj: dict[str, Any] | None = None 

462 

463 if hasattr(key_obj, "private_numbers"): 

464 # Private key 

465 numbers = key_obj.private_numbers() 

466 

467 obj = { 

468 "kty": "RSA", 

469 "key_ops": ["sign"], 

470 "n": to_base64url_uint(numbers.public_numbers.n).decode(), 

471 "e": to_base64url_uint(numbers.public_numbers.e).decode(), 

472 "d": to_base64url_uint(numbers.d).decode(), 

473 "p": to_base64url_uint(numbers.p).decode(), 

474 "q": to_base64url_uint(numbers.q).decode(), 

475 "dp": to_base64url_uint(numbers.dmp1).decode(), 

476 "dq": to_base64url_uint(numbers.dmq1).decode(), 

477 "qi": to_base64url_uint(numbers.iqmp).decode(), 

478 } 

479 

480 elif hasattr(key_obj, "verify"): 

481 # Public key 

482 numbers = key_obj.public_numbers() 

483 

484 obj = { 

485 "kty": "RSA", 

486 "key_ops": ["verify"], 

487 "n": to_base64url_uint(numbers.n).decode(), 

488 "e": to_base64url_uint(numbers.e).decode(), 

489 } 

490 else: 

491 raise InvalidKeyError("Not a public or private key") 

492 

493 if as_dict: 

494 return obj 

495 else: 

496 return json.dumps(obj) 

497 

498 @staticmethod 

499 def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: 

500 try: 

501 if isinstance(jwk, str): 

502 obj = json.loads(jwk) 

503 elif isinstance(jwk, dict): 

504 obj = jwk 

505 else: 

506 raise ValueError 

507 except ValueError: 

508 raise InvalidKeyError("Key is not valid JSON") from None 

509 

510 if obj.get("kty") != "RSA": 

511 raise InvalidKeyError("Not an RSA key") from None 

512 

513 if "d" in obj and "e" in obj and "n" in obj: 

514 # Private key 

515 if "oth" in obj: 

516 raise InvalidKeyError( 

517 "Unsupported RSA private key: > 2 primes not supported" 

518 ) 

519 

520 other_props = ["p", "q", "dp", "dq", "qi"] 

521 props_found = [prop in obj for prop in other_props] 

522 any_props_found = any(props_found) 

523 

524 if any_props_found and not all(props_found): 

525 raise InvalidKeyError( 

526 "RSA key must include all parameters if any are present besides d" 

527 ) from None 

528 

529 public_numbers = RSAPublicNumbers( 

530 from_base64url_uint(obj["e"]), 

531 from_base64url_uint(obj["n"]), 

532 ) 

533 

534 if any_props_found: 

535 numbers = RSAPrivateNumbers( 

536 d=from_base64url_uint(obj["d"]), 

537 p=from_base64url_uint(obj["p"]), 

538 q=from_base64url_uint(obj["q"]), 

539 dmp1=from_base64url_uint(obj["dp"]), 

540 dmq1=from_base64url_uint(obj["dq"]), 

541 iqmp=from_base64url_uint(obj["qi"]), 

542 public_numbers=public_numbers, 

543 ) 

544 else: 

545 d = from_base64url_uint(obj["d"]) 

546 p, q = rsa_recover_prime_factors( 

547 public_numbers.n, d, public_numbers.e 

548 ) 

549 

550 numbers = RSAPrivateNumbers( 

551 d=d, 

552 p=p, 

553 q=q, 

554 dmp1=rsa_crt_dmp1(d, p), 

555 dmq1=rsa_crt_dmq1(d, q), 

556 iqmp=rsa_crt_iqmp(p, q), 

557 public_numbers=public_numbers, 

558 ) 

559 

560 return numbers.private_key() 

561 elif "n" in obj and "e" in obj: 

562 # Public key 

563 return RSAPublicNumbers( 

564 from_base64url_uint(obj["e"]), 

565 from_base64url_uint(obj["n"]), 

566 ).public_key() 

567 else: 

568 raise InvalidKeyError("Not a public or private key") 

569 

570 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: 

571 signature: bytes = key.sign(msg, padding.PKCS1v15(), self.hash_alg()) 

572 return signature 

573 

574 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: 

575 try: 

576 key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) 

577 return True 

578 except InvalidSignature: 

579 return False 

580 

581 class ECAlgorithm(Algorithm): 

582 """ 

583 Performs signing and verification operations using 

584 ECDSA and the specified hash function 

585 """ 

586 

587 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 

588 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 

589 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 

590 

591 _crypto_key_types = cast( 

592 tuple[type[AllowedKeys], ...], 

593 get_args(Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]), 

594 ) 

595 

596 def __init__( 

597 self, 

598 hash_alg: type[hashes.HashAlgorithm], 

599 expected_curve: type[EllipticCurve] | None = None, 

600 ) -> None: 

601 self.hash_alg = hash_alg 

602 self.expected_curve = expected_curve 

603 

604 def _validate_curve(self, key: AllowedECKeys) -> None: 

605 """Validate that the key's curve matches the expected curve.""" 

606 if self.expected_curve is None: 

607 return 

608 

609 if not isinstance(key.curve, self.expected_curve): 

610 raise InvalidKeyError( 

611 f"The key's curve '{key.curve.name}' does not match the expected " 

612 f"curve '{self.expected_curve.name}' for this algorithm" 

613 ) 

614 

615 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: 

616 if isinstance(key, self._crypto_key_types): 

617 ec_key = cast(AllowedECKeys, key) 

618 self._validate_curve(ec_key) 

619 return ec_key 

620 

621 if not isinstance(key, (bytes, str)): 

622 raise TypeError("Expecting a PEM-formatted key.") 

623 

624 key_bytes = force_bytes(key) 

625 

626 # Attempt to load key. We don't know if it's 

627 # a Signing Key or a Verifying Key, so we try 

628 # the Verifying Key first. 

629 try: 

630 if key_bytes.startswith(b"ecdsa-sha2-"): 

631 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes) 

632 else: 

633 public_key = load_pem_public_key(key_bytes) 

634 

635 # Explicit check the key to prevent confusing errors from cryptography 

636 self.check_crypto_key_type(public_key) 

637 ec_public_key = cast(EllipticCurvePublicKey, public_key) 

638 self._validate_curve(ec_public_key) 

639 return ec_public_key 

640 except ValueError: 

641 private_key = load_pem_private_key(key_bytes, password=None) 

642 self.check_crypto_key_type(private_key) 

643 ec_private_key = cast(EllipticCurvePrivateKey, private_key) 

644 self._validate_curve(ec_private_key) 

645 return ec_private_key 

646 

647 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: 

648 der_sig = key.sign(msg, ECDSA(self.hash_alg())) 

649 

650 return der_to_raw_signature(der_sig, key.curve) 

651 

652 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: 

653 try: 

654 der_sig = raw_to_der_signature(sig, key.curve) 

655 except ValueError: 

656 return False 

657 

658 try: 

659 public_key = ( 

660 key.public_key() 

661 if isinstance(key, EllipticCurvePrivateKey) 

662 else key 

663 ) 

664 public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) 

665 return True 

666 except InvalidSignature: 

667 return False 

668 

669 @overload 

670 @staticmethod 

671 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ... 

672 

673 @overload 

674 @staticmethod 

675 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ... 

676 

677 @staticmethod 

678 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: 

679 if isinstance(key_obj, EllipticCurvePrivateKey): 

680 public_numbers = key_obj.public_key().public_numbers() 

681 elif isinstance(key_obj, EllipticCurvePublicKey): 

682 public_numbers = key_obj.public_numbers() 

683 else: 

684 raise InvalidKeyError("Not a public or private key") 

685 

686 if isinstance(key_obj.curve, SECP256R1): 

687 crv = "P-256" 

688 elif isinstance(key_obj.curve, SECP384R1): 

689 crv = "P-384" 

690 elif isinstance(key_obj.curve, SECP521R1): 

691 crv = "P-521" 

692 elif isinstance(key_obj.curve, SECP256K1): 

693 crv = "secp256k1" 

694 else: 

695 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") 

696 

697 obj: dict[str, Any] = { 

698 "kty": "EC", 

699 "crv": crv, 

700 "x": to_base64url_uint( 

701 public_numbers.x, 

702 bit_length=key_obj.curve.key_size, 

703 ).decode(), 

704 "y": to_base64url_uint( 

705 public_numbers.y, 

706 bit_length=key_obj.curve.key_size, 

707 ).decode(), 

708 } 

709 

710 if isinstance(key_obj, EllipticCurvePrivateKey): 

711 obj["d"] = to_base64url_uint( 

712 key_obj.private_numbers().private_value, 

713 bit_length=key_obj.curve.key_size, 

714 ).decode() 

715 

716 if as_dict: 

717 return obj 

718 else: 

719 return json.dumps(obj) 

720 

721 @staticmethod 

722 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: 

723 try: 

724 if isinstance(jwk, str): 

725 obj = json.loads(jwk) 

726 elif isinstance(jwk, dict): 

727 obj = jwk 

728 else: 

729 raise ValueError 

730 except ValueError: 

731 raise InvalidKeyError("Key is not valid JSON") from None 

732 

733 if obj.get("kty") != "EC": 

734 raise InvalidKeyError("Not an Elliptic curve key") from None 

735 

736 if "x" not in obj or "y" not in obj: 

737 raise InvalidKeyError("Not an Elliptic curve key") from None 

738 

739 x = base64url_decode(obj.get("x")) 

740 y = base64url_decode(obj.get("y")) 

741 

742 curve = obj.get("crv") 

743 curve_obj: EllipticCurve 

744 

745 if curve == "P-256": 

746 if len(x) == len(y) == 32: 

747 curve_obj = SECP256R1() 

748 else: 

749 raise InvalidKeyError( 

750 "Coords should be 32 bytes for curve P-256" 

751 ) from None 

752 elif curve == "P-384": 

753 if len(x) == len(y) == 48: 

754 curve_obj = SECP384R1() 

755 else: 

756 raise InvalidKeyError( 

757 "Coords should be 48 bytes for curve P-384" 

758 ) from None 

759 elif curve == "P-521": 

760 if len(x) == len(y) == 66: 

761 curve_obj = SECP521R1() 

762 else: 

763 raise InvalidKeyError( 

764 "Coords should be 66 bytes for curve P-521" 

765 ) from None 

766 elif curve == "secp256k1": 

767 if len(x) == len(y) == 32: 

768 curve_obj = SECP256K1() 

769 else: 

770 raise InvalidKeyError( 

771 "Coords should be 32 bytes for curve secp256k1" 

772 ) 

773 else: 

774 raise InvalidKeyError(f"Invalid curve: {curve}") 

775 

776 public_numbers = EllipticCurvePublicNumbers( 

777 x=int.from_bytes(x, byteorder="big"), 

778 y=int.from_bytes(y, byteorder="big"), 

779 curve=curve_obj, 

780 ) 

781 

782 if "d" not in obj: 

783 return public_numbers.public_key() 

784 

785 d = base64url_decode(obj.get("d")) 

786 if len(d) != len(x): 

787 raise InvalidKeyError( 

788 "D should be {} bytes for curve {}", len(x), curve 

789 ) 

790 

791 return EllipticCurvePrivateNumbers( 

792 int.from_bytes(d, byteorder="big"), public_numbers 

793 ).private_key() 

794 

795 class RSAPSSAlgorithm(RSAAlgorithm): 

796 """ 

797 Performs a signature using RSASSA-PSS with MGF1 

798 """ 

799 

800 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: 

801 signature: bytes = key.sign( 

802 msg, 

803 padding.PSS( 

804 mgf=padding.MGF1(self.hash_alg()), 

805 salt_length=self.hash_alg().digest_size, 

806 ), 

807 self.hash_alg(), 

808 ) 

809 return signature 

810 

811 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: 

812 try: 

813 key.verify( 

814 sig, 

815 msg, 

816 padding.PSS( 

817 mgf=padding.MGF1(self.hash_alg()), 

818 salt_length=self.hash_alg().digest_size, 

819 ), 

820 self.hash_alg(), 

821 ) 

822 return True 

823 except InvalidSignature: 

824 return False 

825 

826 class OKPAlgorithm(Algorithm): 

827 """ 

828 Performs signing and verification operations using EdDSA 

829 

830 This class requires ``cryptography>=2.6`` to be installed. 

831 """ 

832 

833 _crypto_key_types = cast( 

834 tuple[type[AllowedKeys], ...], 

835 get_args( 

836 Union[ 

837 Ed25519PrivateKey, 

838 Ed25519PublicKey, 

839 Ed448PrivateKey, 

840 Ed448PublicKey, 

841 ] 

842 ), 

843 ) 

844 

845 def __init__(self, **kwargs: Any) -> None: 

846 pass 

847 

848 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: 

849 if not isinstance(key, (str, bytes)): 

850 self.check_crypto_key_type(key) 

851 return key 

852 

853 key_str = key.decode("utf-8") if isinstance(key, bytes) else key 

854 key_bytes = key.encode("utf-8") if isinstance(key, str) else key 

855 

856 loaded_key: PublicKeyTypes | PrivateKeyTypes 

857 if "-----BEGIN PUBLIC" in key_str: 

858 loaded_key = load_pem_public_key(key_bytes) 

859 elif "-----BEGIN PRIVATE" in key_str: 

860 loaded_key = load_pem_private_key(key_bytes, password=None) 

861 elif key_str[0:4] == "ssh-": 

862 loaded_key = load_ssh_public_key(key_bytes) 

863 else: 

864 raise InvalidKeyError("Not a public or private key") 

865 

866 # Explicit check the key to prevent confusing errors from cryptography 

867 self.check_crypto_key_type(loaded_key) 

868 return cast("AllowedOKPKeys", loaded_key) 

869 

870 def sign( 

871 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey 

872 ) -> bytes: 

873 """ 

874 Sign a message ``msg`` using the EdDSA private key ``key`` 

875 :param str|bytes msg: Message to sign 

876 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` 

877 or :class:`.Ed448PrivateKey` isinstance 

878 :return bytes signature: The signature, as bytes 

879 """ 

880 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

881 signature: bytes = key.sign(msg_bytes) 

882 return signature 

883 

884 def verify( 

885 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes 

886 ) -> bool: 

887 """ 

888 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` 

889 

890 :param str|bytes sig: EdDSA signature to check ``msg`` against 

891 :param str|bytes msg: Message to sign 

892 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key: 

893 A private or public EdDSA key instance 

894 :return bool verified: True if signature is valid, False if not. 

895 """ 

896 try: 

897 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

898 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig 

899 

900 public_key = ( 

901 key.public_key() 

902 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) 

903 else key 

904 ) 

905 public_key.verify(sig_bytes, msg_bytes) 

906 return True # If no exception was raised, the signature is valid. 

907 except InvalidSignature: 

908 return False 

909 

910 @overload 

911 @staticmethod 

912 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ... 

913 

914 @overload 

915 @staticmethod 

916 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ... 

917 

918 @staticmethod 

919 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: 

920 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): 

921 x = key.public_bytes( 

922 encoding=Encoding.Raw, 

923 format=PublicFormat.Raw, 

924 ) 

925 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" 

926 

927 obj = { 

928 "x": base64url_encode(force_bytes(x)).decode(), 

929 "kty": "OKP", 

930 "crv": crv, 

931 } 

932 

933 if as_dict: 

934 return obj 

935 else: 

936 return json.dumps(obj) 

937 

938 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): 

939 d = key.private_bytes( 

940 encoding=Encoding.Raw, 

941 format=PrivateFormat.Raw, 

942 encryption_algorithm=NoEncryption(), 

943 ) 

944 

945 x = key.public_key().public_bytes( 

946 encoding=Encoding.Raw, 

947 format=PublicFormat.Raw, 

948 ) 

949 

950 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" 

951 obj = { 

952 "x": base64url_encode(force_bytes(x)).decode(), 

953 "d": base64url_encode(force_bytes(d)).decode(), 

954 "kty": "OKP", 

955 "crv": crv, 

956 } 

957 

958 if as_dict: 

959 return obj 

960 else: 

961 return json.dumps(obj) 

962 

963 raise InvalidKeyError("Not a public or private key") 

964 

965 @staticmethod 

966 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: 

967 try: 

968 if isinstance(jwk, str): 

969 obj = json.loads(jwk) 

970 elif isinstance(jwk, dict): 

971 obj = jwk 

972 else: 

973 raise ValueError 

974 except ValueError: 

975 raise InvalidKeyError("Key is not valid JSON") from None 

976 

977 if obj.get("kty") != "OKP": 

978 raise InvalidKeyError("Not an Octet Key Pair") 

979 

980 curve = obj.get("crv") 

981 if curve != "Ed25519" and curve != "Ed448": 

982 raise InvalidKeyError(f"Invalid curve: {curve}") 

983 

984 if "x" not in obj: 

985 raise InvalidKeyError('OKP should have "x" parameter') 

986 x = base64url_decode(obj.get("x")) 

987 

988 try: 

989 if "d" not in obj: 

990 if curve == "Ed25519": 

991 return Ed25519PublicKey.from_public_bytes(x) 

992 return Ed448PublicKey.from_public_bytes(x) 

993 d = base64url_decode(obj.get("d")) 

994 if curve == "Ed25519": 

995 return Ed25519PrivateKey.from_private_bytes(d) 

996 return Ed448PrivateKey.from_private_bytes(d) 

997 except ValueError as err: 

998 raise InvalidKeyError("Invalid key parameter") from err