Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

466 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6import os 

7import sys 

8from abc import ABC, abstractmethod 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 ClassVar, 

13 Literal, 

14 NoReturn, 

15 Union, 

16 cast, 

17 get_args, 

18 overload, 

19) 

20 

21from .exceptions import InvalidKeyError 

22from .types import HashlibHash, JWKDict 

23from .utils import ( 

24 base64url_decode, 

25 base64url_encode, 

26 der_to_raw_signature, 

27 force_bytes, 

28 from_base64url_uint, 

29 is_pem_format, 

30 is_ssh_key, 

31 raw_to_der_signature, 

32 to_base64url_uint, 

33) 

34 

# ``cryptography`` is an optional dependency. Everything below is only
# importable when it is installed; ``has_crypto`` records the outcome so the
# rest of the module can decide whether to define / register the asymmetric
# (RSA, RSA-PSS, ECDSA, EdDSA) algorithms.
try:
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        # Python 3.9 and lower
        from typing_extensions import TypeAlias

    # Type aliases for convenience in algorithms method signatures
    AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
    AllowedECKeys: TypeAlias = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]
    AllowedOKPKeys: TypeAlias = Union[
        Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
    ]
    AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
    #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
    AllowedPrivateKeys: TypeAlias = Union[
        RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
    ]
    #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
    AllowedPublicKeys: TypeAlias = Union[
        RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
    ]

    # PrivateKeyTypes / PublicKeyTypes are only imported for type checkers and
    # documentation builds. At runtime they appear solely in annotations, which
    # `from __future__ import annotations` keeps unevaluated.
    if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

    has_crypto = True
except ModuleNotFoundError:
    if sys.version_info >= (3, 11):
        from typing import Never
    else:
        from typing_extensions import Never

    # Without ``cryptography`` the key aliases are bound to ``Never`` so that
    # the names still exist for the annotations that reference them.
    AllowedRSAKeys = Never  # type: ignore[misc]
    AllowedECKeys = Never  # type: ignore[misc]
    AllowedOKPKeys = Never  # type: ignore[misc]
    AllowedKeys = Never  # type: ignore[misc]
    AllowedPrivateKeys = Never  # type: ignore[misc]
    AllowedPublicKeys = Never  # type: ignore[misc]
    has_crypto = False

122 

123 

#: Algorithm names that are only registered when ``cryptography`` is
#: installed (they are the entries added under ``has_crypto`` in
#: ``get_default_algorithms``).
requires_cryptography = {
    # RSASSA-PKCS1-v1_5
    "RS256",
    "RS384",
    "RS512",
    # RSASSA-PSS
    "PS256",
    "PS384",
    "PS512",
    # ECDSA; "ES512" and "ES521" both map to SHA-512 over P-521
    "ES256",
    "ES256K",
    "ES384",
    "ES512",
    "ES521",
    # EdDSA (Ed25519 / Ed448)
    "EdDSA",
}

138 

139 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build the mapping from algorithm name to a ready-to-use ``Algorithm``.

    ``"none"`` and the HMAC family (HS256/HS384/HS512) are always present.
    The RSA, ECDSA, RSA-PSS and EdDSA entries are added only when
    ``cryptography`` could be imported (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {"none": NoneAlgorithm()}
    algorithms.update(
        (f"HS{bits}", HMACAlgorithm(digest))
        for bits, digest in (
            (256, HMACAlgorithm.SHA256),
            (384, HMACAlgorithm.SHA384),
            (512, HMACAlgorithm.SHA512),
        )
    )

    if not has_crypto:
        return algorithms

    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    # Backward compat for #219 fix
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    algorithms["EdDSA"] = OKPAlgorithm()

    return algorithms

172 

173 

class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Concrete subclasses implement key preparation (``prepare_key``), signing
    (``sign``), verification (``verify``) and JWK conversion
    (``to_jwk`` / ``from_jwk``).
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct
    # cryptography key family. Subclasses backed by ``cryptography`` keys set
    # this to the accepted key classes; it stays ``None`` for the others.
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        :param bytestr: The data to hash.
        :return: The raw digest bytes.
        :raises NotImplementedError: if the algorithm has no ``hash_alg``.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        # ``cryptography`` hash classes go through the hashes.Hash API; any
        # other hash_alg is treated as a hashlib-style constructor.
        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        else:
            return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Materialize the names: a bare generator expression would be
            # rendered as "<generator object ...>" in the message below.
            valid_classes = tuple(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK (a dict when ``as_dict`` is true,
        otherwise a JSON string).
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        The base implementation performs no check and always returns None.
        """
        return None

280 

281 

class NoneAlgorithm(Algorithm):
    """
    The ``"none"`` algorithm: a placeholder for when no signing or
    verification operations are required.

    Signing yields an empty signature and verification always fails.
    """

    def prepare_key(self, key: str | None) -> None:
        # The empty string is accepted as an alias for "no key".
        normalized = None if key == "" else key
        if normalized is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        empty_signature = b""
        return empty_signature

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # Nothing can ever be verified without a key.
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key material to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key material to deserialize.
        raise NotImplementedError()

310 

311 

class HMACAlgorithm(Algorithm):
    """
    Sign and verify tokens with HMAC over a hashlib hash constructor
    (SHA-256, SHA-384 or SHA-512).
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    @staticmethod
    def _is_jwk_json(secret: bytes) -> bool:
        """Return True when ``secret`` parses as a JSON object with a ``kty`` member."""
        if not secret.lstrip().startswith(b"{"):
            return False
        try:
            parsed = json.loads(secret)
        except ValueError:
            return False
        return isinstance(parsed, dict) and "kty" in parsed

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Validate an HMAC secret and return it as bytes.

        :raises InvalidKeyError: if the secret is empty, is PEM / SSH key
            material, or is the JSON text of a JWK.
        """
        secret = force_bytes(key)

        if not secret:
            raise InvalidKeyError("HMAC key must not be empty.")

        if is_pem_format(secret) or is_ssh_key(secret):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        # Algorithm-confusion guard: a token header can select alg=HS*, and
        # HMAC is the only family here that accepts arbitrary bytes as a key.
        # Raw JWK JSON is not the secret itself (even for kty=oct), so it must
        # be loaded through PyJWK / from_jwk instead.
        if self._is_jwk_json(secret):
            raise InvalidKeyError(
                "The specified key looks like a JWK and should not be "
                "used directly as an HMAC secret. Load it via "
                "PyJWK / HMACAlgorithm.from_jwk first."
            )

        return secret

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """Wrap the secret in an ``oct`` JWK, as a dict or as JSON text."""
        encoded_secret = base64url_encode(force_bytes(key_obj)).decode()
        jwk = {"k": encoded_secret, "kty": "oct"}
        return jwk if as_dict else json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Extract the raw secret from an ``oct`` JWK (dict or JSON text).

        :raises InvalidKeyError: if the input is not valid JSON / a dict, or
            its ``kty`` is not ``oct``.
        """
        if isinstance(jwk, str):
            try:
                obj: JWKDict = json.loads(jwk)
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None
        elif isinstance(jwk, dict):
            obj = jwk
        else:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(obj["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """Warn when the secret is shorter than the hash's digest size."""
        hasher = self.hash_alg()
        min_length = hasher.digest_size
        if len(key) >= min_length:
            return None
        return (
            f"The HMAC key is {len(key)} bytes long, which is below "
            f"the minimum recommended length of {min_length} bytes for "
            f"{hasher.name.upper()}. "
            "See RFC 7518 Section 3.2."
        )

    def sign(self, msg: bytes, key: bytes) -> bytes:
        mac = hmac.new(key, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        # Constant-time comparison against a freshly computed MAC.
        expected = self.sign(msg, key)
        return hmac.compare_digest(sig, expected)

412 

413 

if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        # Key classes accepted by check_crypto_key_type / prepare_key.
        _crypto_key_types = cast(
            tuple[type[AllowedKeys], ...],
            get_args(Union[RSAPrivateKey, RSAPublicKey]),
        )
        # Keys below this size (in bits) make check_key_length return a warning.
        _MIN_KEY_SIZE: ClassVar[int] = 2048

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            self.hash_alg = hash_alg

        def check_key_length(self, key: AllowedRSAKeys) -> str | None:
            """Return a warning if ``key`` is smaller than ``_MIN_KEY_SIZE`` bits, else None."""
            if key.key_size < self._MIN_KEY_SIZE:
                return (
                    f"The RSA key is {key.key_size} bits long, which is below "
                    f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. "
                    f"See NIST SP 800-131A."
                )
            return None

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return ``key`` as a ``cryptography`` RSA key object.

            RSA key objects are returned unchanged. ``str``/``bytes`` input
            beginning with ``ssh-rsa`` is loaded as an SSH public key;
            anything else is tried as an unencrypted PEM private key and, if
            that raises ValueError, as a PEM public key.

            :raises TypeError: if ``key`` is neither an RSA key nor str/bytes.
            :raises InvalidKeyError: if the loaded key is not an RSA key or
                the PEM public-key fallback cannot parse it. Other loader
                errors may propagate unchanged.
            """
            if isinstance(key, self._crypto_key_types):
                # Cast is required for type narrowing on Python 3.9's mypy
                # but redundant on newer mypy versions; suppress both
                # diagnostics so the line works across all supported envs.
                return cast(AllowedRSAKeys, key)  # type: ignore[redundant-cast,unused-ignore]

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                else:
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        key_bytes, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
            except ValueError:
                try:
                    public_key = load_pem_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK (dict if ``as_dict``, else JSON text).

            Private keys include the private exponent and CRT parameters;
            public keys include only ``n`` and ``e``.

            :raises InvalidKeyError: if ``key_obj`` is neither kind of key.
            """
            obj: dict[str, Any] | None = None

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key from a JWK (dict or JSON text).

            A JWK with ``d``, ``e`` and ``n`` yields a private key. The CRT
            parameters (``p``, ``q``, ``dp``, ``dq``, ``qi``) are either all
            used or, when all absent, recovered from ``n``, ``e`` and ``d``.
            A JWK with only ``n`` and ``e`` yields a public key.

            :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``,
                multi-prime keys (``oth``), partial CRT parameters, or missing
                ``n``/``e``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    ) from None

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d was given: recover the primes and derive the CRT
                    # parameters from them.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PKCS#1 v1.5 padding and the configured hash."""
            signature: bytes = key.sign(msg, padding.PKCS1v15(), self.hash_alg())
            return signature

        def verify(self, msg: bytes, key: AllowedRSAKeys, sig: bytes) -> bool:
            """
            Return True if ``sig`` is a valid PKCS#1 v1.5 signature of ``msg``.

            ``key`` may be a public key or a private key; for a private key
            its public half is used (matching ECAlgorithm / OKPAlgorithm).
            """
            public_key = key.public_key() if isinstance(key, RSAPrivateKey) else key
            try:
                public_key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

606 

607 class ECAlgorithm(Algorithm): 

608 """ 

609 Performs signing and verification operations using 

610 ECDSA and the specified hash function 

611 """ 

612 

613 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 

614 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 

615 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 

616 

617 _crypto_key_types = cast( 

618 tuple[type[AllowedKeys], ...], 

619 get_args(Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]), 

620 ) 

621 

622 def __init__( 

623 self, 

624 hash_alg: type[hashes.HashAlgorithm], 

625 expected_curve: type[EllipticCurve] | None = None, 

626 ) -> None: 

627 self.hash_alg = hash_alg 

628 self.expected_curve = expected_curve 

629 

630 def _validate_curve(self, key: AllowedECKeys) -> None: 

631 """Validate that the key's curve matches the expected curve.""" 

632 if self.expected_curve is None: 

633 return 

634 

635 if not isinstance(key.curve, self.expected_curve): 

636 raise InvalidKeyError( 

637 f"The key's curve '{key.curve.name}' does not match the expected " 

638 f"curve '{self.expected_curve.name}' for this algorithm" 

639 ) 

640 

641 def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: 

642 if isinstance(key, self._crypto_key_types): 

643 # See note in RSAAlgorithm.prepare_key. 

644 ec_key = cast(AllowedECKeys, key) # type: ignore[redundant-cast,unused-ignore] 

645 self._validate_curve(ec_key) 

646 return ec_key 

647 

648 if not isinstance(key, (bytes, str)): 

649 raise TypeError("Expecting a PEM-formatted key.") 

650 

651 key_bytes = force_bytes(key) 

652 

653 # Attempt to load key. We don't know if it's 

654 # a Signing Key or a Verifying Key, so we try 

655 # the Verifying Key first. 

656 try: 

657 if key_bytes.startswith(b"ecdsa-sha2-"): 

658 public_key: PublicKeyTypes = load_ssh_public_key(key_bytes) 

659 else: 

660 public_key = load_pem_public_key(key_bytes) 

661 

662 # Explicit check the key to prevent confusing errors from cryptography 

663 self.check_crypto_key_type(public_key) 

664 ec_public_key = cast(EllipticCurvePublicKey, public_key) 

665 self._validate_curve(ec_public_key) 

666 return ec_public_key 

667 except ValueError: 

668 private_key = load_pem_private_key(key_bytes, password=None) 

669 self.check_crypto_key_type(private_key) 

670 ec_private_key = cast(EllipticCurvePrivateKey, private_key) 

671 self._validate_curve(ec_private_key) 

672 return ec_private_key 

673 

674 def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: 

675 der_sig = key.sign(msg, ECDSA(self.hash_alg())) 

676 

677 return der_to_raw_signature(der_sig, key.curve) 

678 

679 def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: 

680 try: 

681 der_sig = raw_to_der_signature(sig, key.curve) 

682 except ValueError: 

683 return False 

684 

685 try: 

686 public_key = ( 

687 key.public_key() 

688 if isinstance(key, EllipticCurvePrivateKey) 

689 else key 

690 ) 

691 public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) 

692 return True 

693 except InvalidSignature: 

694 return False 

695 

696 @overload 

697 @staticmethod 

698 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ... 

699 

700 @overload 

701 @staticmethod 

702 def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ... 

703 

704 @staticmethod 

705 def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: 

706 if isinstance(key_obj, EllipticCurvePrivateKey): 

707 public_numbers = key_obj.public_key().public_numbers() 

708 elif isinstance(key_obj, EllipticCurvePublicKey): 

709 public_numbers = key_obj.public_numbers() 

710 else: 

711 raise InvalidKeyError("Not a public or private key") 

712 

713 if isinstance(key_obj.curve, SECP256R1): 

714 crv = "P-256" 

715 elif isinstance(key_obj.curve, SECP384R1): 

716 crv = "P-384" 

717 elif isinstance(key_obj.curve, SECP521R1): 

718 crv = "P-521" 

719 elif isinstance(key_obj.curve, SECP256K1): 

720 crv = "secp256k1" 

721 else: 

722 raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") 

723 

724 obj: dict[str, Any] = { 

725 "kty": "EC", 

726 "crv": crv, 

727 "x": to_base64url_uint( 

728 public_numbers.x, 

729 bit_length=key_obj.curve.key_size, 

730 ).decode(), 

731 "y": to_base64url_uint( 

732 public_numbers.y, 

733 bit_length=key_obj.curve.key_size, 

734 ).decode(), 

735 } 

736 

737 if isinstance(key_obj, EllipticCurvePrivateKey): 

738 obj["d"] = to_base64url_uint( 

739 key_obj.private_numbers().private_value, 

740 bit_length=key_obj.curve.key_size, 

741 ).decode() 

742 

743 if as_dict: 

744 return obj 

745 else: 

746 return json.dumps(obj) 

747 

748 @staticmethod 

749 def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: 

750 try: 

751 if isinstance(jwk, str): 

752 obj = json.loads(jwk) 

753 elif isinstance(jwk, dict): 

754 obj = jwk 

755 else: 

756 raise ValueError 

757 except ValueError: 

758 raise InvalidKeyError("Key is not valid JSON") from None 

759 

760 if obj.get("kty") != "EC": 

761 raise InvalidKeyError("Not an Elliptic curve key") from None 

762 

763 if "x" not in obj or "y" not in obj: 

764 raise InvalidKeyError("Not an Elliptic curve key") from None 

765 

766 x = base64url_decode(obj.get("x")) 

767 y = base64url_decode(obj.get("y")) 

768 

769 curve = obj.get("crv") 

770 curve_obj: EllipticCurve 

771 

772 if curve == "P-256": 

773 if len(x) == len(y) == 32: 

774 curve_obj = SECP256R1() 

775 else: 

776 raise InvalidKeyError( 

777 "Coords should be 32 bytes for curve P-256" 

778 ) from None 

779 elif curve == "P-384": 

780 if len(x) == len(y) == 48: 

781 curve_obj = SECP384R1() 

782 else: 

783 raise InvalidKeyError( 

784 "Coords should be 48 bytes for curve P-384" 

785 ) from None 

786 elif curve == "P-521": 

787 if len(x) == len(y) == 66: 

788 curve_obj = SECP521R1() 

789 else: 

790 raise InvalidKeyError( 

791 "Coords should be 66 bytes for curve P-521" 

792 ) from None 

793 elif curve == "secp256k1": 

794 if len(x) == len(y) == 32: 

795 curve_obj = SECP256K1() 

796 else: 

797 raise InvalidKeyError( 

798 "Coords should be 32 bytes for curve secp256k1" 

799 ) 

800 else: 

801 raise InvalidKeyError(f"Invalid curve: {curve}") 

802 

803 public_numbers = EllipticCurvePublicNumbers( 

804 x=int.from_bytes(x, byteorder="big"), 

805 y=int.from_bytes(y, byteorder="big"), 

806 curve=curve_obj, 

807 ) 

808 

809 if "d" not in obj: 

810 return public_numbers.public_key() 

811 

812 d = base64url_decode(obj.get("d")) 

813 if len(d) != len(x): 

814 raise InvalidKeyError( 

815 "D should be {} bytes for curve {}", len(x), curve 

816 ) 

817 

818 return EllipticCurvePrivateNumbers( 

819 int.from_bytes(d, byteorder="big"), public_numbers 

820 ).private_key() 

821 

822 class RSAPSSAlgorithm(RSAAlgorithm): 

823 """ 

824 Performs a signature using RSASSA-PSS with MGF1 

825 """ 

826 

827 def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: 

828 signature: bytes = key.sign( 

829 msg, 

830 padding.PSS( 

831 mgf=padding.MGF1(self.hash_alg()), 

832 salt_length=self.hash_alg().digest_size, 

833 ), 

834 self.hash_alg(), 

835 ) 

836 return signature 

837 

838 def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: 

839 try: 

840 key.verify( 

841 sig, 

842 msg, 

843 padding.PSS( 

844 mgf=padding.MGF1(self.hash_alg()), 

845 salt_length=self.hash_alg().digest_size, 

846 ), 

847 self.hash_alg(), 

848 ) 

849 return True 

850 except InvalidSignature: 

851 return False 

852 

853 class OKPAlgorithm(Algorithm): 

854 """ 

855 Performs signing and verification operations using EdDSA 

856 

857 This class requires ``cryptography>=2.6`` to be installed. 

858 """ 

859 

860 _crypto_key_types = cast( 

861 tuple[type[AllowedKeys], ...], 

862 get_args( 

863 Union[ 

864 Ed25519PrivateKey, 

865 Ed25519PublicKey, 

866 Ed448PrivateKey, 

867 Ed448PublicKey, 

868 ] 

869 ), 

870 ) 

871 

872 def __init__(self, **kwargs: Any) -> None: 

873 pass 

874 

875 def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: 

876 if not isinstance(key, (str, bytes)): 

877 self.check_crypto_key_type(key) 

878 return key 

879 

880 key_str = key.decode("utf-8") if isinstance(key, bytes) else key 

881 key_bytes = key.encode("utf-8") if isinstance(key, str) else key 

882 

883 loaded_key: PublicKeyTypes | PrivateKeyTypes 

884 if "-----BEGIN PUBLIC" in key_str: 

885 loaded_key = load_pem_public_key(key_bytes) 

886 elif "-----BEGIN PRIVATE" in key_str: 

887 loaded_key = load_pem_private_key(key_bytes, password=None) 

888 elif key_str[0:4] == "ssh-": 

889 loaded_key = load_ssh_public_key(key_bytes) 

890 else: 

891 raise InvalidKeyError("Not a public or private key") 

892 

893 # Explicit check the key to prevent confusing errors from cryptography 

894 self.check_crypto_key_type(loaded_key) 

895 return cast("AllowedOKPKeys", loaded_key) 

896 

897 def sign( 

898 self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey 

899 ) -> bytes: 

900 """ 

901 Sign a message ``msg`` using the EdDSA private key ``key`` 

902 :param str|bytes msg: Message to sign 

903 :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` 

904 or :class:`.Ed448PrivateKey` isinstance 

905 :return bytes signature: The signature, as bytes 

906 """ 

907 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

908 signature: bytes = key.sign(msg_bytes) 

909 return signature 

910 

911 def verify( 

912 self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes 

913 ) -> bool: 

914 """ 

915 Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` 

916 

917 :param str|bytes sig: EdDSA signature to check ``msg`` against 

918 :param str|bytes msg: Message to sign 

919 :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key: 

920 A private or public EdDSA key instance 

921 :return bool verified: True if signature is valid, False if not. 

922 """ 

923 try: 

924 msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg 

925 sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig 

926 

927 public_key = ( 

928 key.public_key() 

929 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) 

930 else key 

931 ) 

932 public_key.verify(sig_bytes, msg_bytes) 

933 return True # If no exception was raised, the signature is valid. 

934 except InvalidSignature: 

935 return False 

936 

937 @overload 

938 @staticmethod 

939 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ... 

940 

941 @overload 

942 @staticmethod 

943 def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ... 

944 

945 @staticmethod 

946 def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: 

947 if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): 

948 x = key.public_bytes( 

949 encoding=Encoding.Raw, 

950 format=PublicFormat.Raw, 

951 ) 

952 crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" 

953 

954 obj = { 

955 "x": base64url_encode(force_bytes(x)).decode(), 

956 "kty": "OKP", 

957 "crv": crv, 

958 } 

959 

960 if as_dict: 

961 return obj 

962 else: 

963 return json.dumps(obj) 

964 

965 if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): 

966 d = key.private_bytes( 

967 encoding=Encoding.Raw, 

968 format=PrivateFormat.Raw, 

969 encryption_algorithm=NoEncryption(), 

970 ) 

971 

972 x = key.public_key().public_bytes( 

973 encoding=Encoding.Raw, 

974 format=PublicFormat.Raw, 

975 ) 

976 

977 crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" 

978 obj = { 

979 "x": base64url_encode(force_bytes(x)).decode(), 

980 "d": base64url_encode(force_bytes(d)).decode(), 

981 "kty": "OKP", 

982 "crv": crv, 

983 } 

984 

985 if as_dict: 

986 return obj 

987 else: 

988 return json.dumps(obj) 

989 

990 raise InvalidKeyError("Not a public or private key") 

991 

992 @staticmethod 

993 def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: 

994 try: 

995 if isinstance(jwk, str): 

996 obj = json.loads(jwk) 

997 elif isinstance(jwk, dict): 

998 obj = jwk 

999 else: 

1000 raise ValueError 

1001 except ValueError: 

1002 raise InvalidKeyError("Key is not valid JSON") from None 

1003 

1004 if obj.get("kty") != "OKP": 

1005 raise InvalidKeyError("Not an Octet Key Pair") 

1006 

1007 curve = obj.get("crv") 

1008 if curve != "Ed25519" and curve != "Ed448": 

1009 raise InvalidKeyError(f"Invalid curve: {curve}") 

1010 

1011 if "x" not in obj: 

1012 raise InvalidKeyError('OKP should have "x" parameter') 

1013 x = base64url_decode(obj.get("x")) 

1014 

1015 try: 

1016 if "d" not in obj: 

1017 if curve == "Ed25519": 

1018 return Ed25519PublicKey.from_public_bytes(x) 

1019 return Ed448PublicKey.from_public_bytes(x) 

1020 d = base64url_decode(obj.get("d")) 

1021 if curve == "Ed25519": 

1022 return Ed25519PrivateKey.from_private_bytes(d) 

1023 return Ed448PrivateKey.from_private_bytes(d) 

1024 except ValueError as err: 

1025 raise InvalidKeyError("Invalid key parameter") from err