Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/algorithms.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

452 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import hmac 

5import json 

6import os 

7from abc import ABC, abstractmethod 

8from typing import ( 

9 TYPE_CHECKING, 

10 Any, 

11 ClassVar, 

12 Literal, 

13 NoReturn, 

14 Union, 

15 cast, 

16 overload, 

17) 

18 

19from .exceptions import InvalidKeyError 

20from .types import HashlibHash, JWKDict 

21from .utils import ( 

22 base64url_decode, 

23 base64url_encode, 

24 der_to_raw_signature, 

25 force_bytes, 

26 from_base64url_uint, 

27 is_pem_format, 

28 is_ssh_key, 

29 raw_to_der_signature, 

30 to_base64url_uint, 

31) 

32 

try:
    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        ECDSA,
        SECP256K1,
        SECP256R1,
        SECP384R1,
        SECP521R1,
        EllipticCurve,
        EllipticCurvePrivateKey,
        EllipticCurvePrivateNumbers,
        EllipticCurvePublicKey,
        EllipticCurvePublicNumbers,
    )
    from cryptography.hazmat.primitives.asymmetric.ed448 import (
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    # pyjwt-964: these tuples serve both as type-checking aids and as the
    # runtime isinstance() filters that validate keys handed to an algorithm.
    # On Python >= 3.10 they could be replaced by the Union aliases below.
    ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
    ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
    ALLOWED_OKP_KEY_TYPES = (
        Ed25519PrivateKey,
        Ed25519PublicKey,
        Ed448PrivateKey,
        Ed448PublicKey,
    )
    # Every supported key class, in RSA -> EC -> OKP family order.
    ALLOWED_KEY_TYPES = (
        *ALLOWED_RSA_KEY_TYPES,
        *ALLOWED_EC_KEY_TYPES,
        *ALLOWED_OKP_KEY_TYPES,
    )
    ALLOWED_PRIVATE_KEY_TYPES = (
        RSAPrivateKey,
        EllipticCurvePrivateKey,
        Ed25519PrivateKey,
        Ed448PrivateKey,
    )
    ALLOWED_PUBLIC_KEY_TYPES = (
        RSAPublicKey,
        EllipticCurvePublicKey,
        Ed25519PublicKey,
        Ed448PublicKey,
    )

    # The aliases below are only materialized for static type checkers and
    # when the SPHINX_BUILD environment variable is non-empty (presumably set
    # by the documentation build).
    if TYPE_CHECKING or os.getenv("SPHINX_BUILD"):
        import sys

        if sys.version_info < (3, 10):
            # Python 3.9 and lower
            from typing_extensions import TypeAlias
        else:
            from typing import TypeAlias

        from cryptography.hazmat.primitives.asymmetric.types import (
            PrivateKeyTypes,
            PublicKeyTypes,
        )

        # Type aliases for convenience in algorithms method signatures
        AllowedRSAKeys: TypeAlias = Union[RSAPrivateKey, RSAPublicKey]
        AllowedECKeys: TypeAlias = Union[
            EllipticCurvePrivateKey, EllipticCurvePublicKey
        ]
        AllowedOKPKeys: TypeAlias = Union[
            Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey
        ]
        AllowedKeys: TypeAlias = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys]
        #: Type alias for allowed ``cryptography`` private keys (requires ``cryptography`` to be installed)
        AllowedPrivateKeys: TypeAlias = Union[
            RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey
        ]
        #: Type alias for allowed ``cryptography`` public keys (requires ``cryptography`` to be installed)
        AllowedPublicKeys: TypeAlias = Union[
            RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey
        ]

    has_crypto = True
except ModuleNotFoundError:
    # ``cryptography`` is optional; without it only the "none" and HMAC
    # algorithms are registered by get_default_algorithms().
    has_crypto = False

139 

140 

# Names of algorithms that can only be used when ``cryptography`` is installed.
# RSA (RS*) and RSA-PSS (PS*) share the same three digest sizes.
requires_cryptography = {
    *(f"{family}{bits}" for family in ("RS", "PS") for bits in (256, 384, 512)),
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    "EdDSA",
}

155 

156 

def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Returns the algorithms that are implemented by the library.

    Each call builds a new dict holding new algorithm instances. The
    ``none`` and HMAC (``HS*``) entries are always present; the RSA, ECDSA,
    RSA-PSS and EdDSA entries are added only when ``has_crypto`` is true.
    """
    algorithms: dict[str, Algorithm] = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }
    if not has_crypto:
        return algorithms

    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    # Backward compat for #219 fix: "ES512" also maps to SHA-512 over P-521.
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    algorithms["EdDSA"] = OKPAlgorithm()
    return algorithms

189 

190 

class Algorithm(ABC):
    """
    The interface for an algorithm used to sign and verify tokens.

    Concrete subclasses implement key preparation, signing, verification and
    JWK (de)serialization for one algorithm family.
    """

    # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
    # ``cryptography``-backed subclasses set this to their accepted key classes;
    # ``None`` marks an algorithm that does not use ``cryptography`` keys.
    _crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Compute a hash digest using the specified algorithm's hash algorithm.

        :param bytestr: The data to hash.
        :return: The raw digest bytes.
        :raises NotImplementedError: if the algorithm has no ``hash_alg``.
        """
        # lookup self.hash_alg if defined in a way that mypy can understand
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        if (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        ):
            # ``cryptography`` hash classes are instantiated and fed to a Hash context.
            digest = hashes.Hash(hash_alg(), backend=default_backend())
            digest.update(bytestr)
            return bytes(digest.finalize())
        # Otherwise hash_alg is called like a hashlib constructor (e.g. hashlib.sha256).
        return bytes(hash_alg(bytestr).digest())

    def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes) -> None:
        """Check that the key belongs to the right cryptographic family.

        Note that this method only works when ``cryptography`` is installed.

        :param key: Potentially a cryptography key
        :type key: :py:data:`PublicKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PublicKeyTypes>` | :py:data:`PrivateKeyTypes <cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes>`
        :raises ValueError: if ``cryptography`` is not installed, or this method is called by a non-cryptography algorithm
        :raises InvalidKeyError: if the key doesn't match the expected key classes
        """
        if not has_crypto or self._crypto_key_types is None:
            raise ValueError(
                "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
            )

        if not isinstance(key, self._crypto_key_types):
            # Join the names eagerly; interpolating a generator expression
            # would render as "<generator object ...>" in the message.
            valid_classes = ", ".join(cls.__name__ for cls in self._crypto_key_types)
            actual_class = key.__class__.__name__
            self_class = self.__class__.__name__
            raise InvalidKeyError(
                f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
            )

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(
        key_obj: Any, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> JWKDict | str:
        """
        Serializes a given key into a JWK.

        Returns a dict when ``as_dict`` is true, otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserializes a given key from JWK back into a key object
        """

    def check_key_length(self, key: Any) -> str | None:
        """
        Return a warning message if the key is below the minimum
        recommended length for this algorithm, or None if adequate.

        The base implementation performs no check and always returns None.
        """
        return None

297 

298 

class NoneAlgorithm(Algorithm):
    """
    Placeholder for use when no signing or verification
    operations are required.

    Signing yields an empty signature and verification always reports
    failure, so no token can be verified with this algorithm.
    """

    def prepare_key(self, key: str | None) -> None:
        """Accept only ``None`` as the key; ``""`` is treated as ``None``."""
        if key is None or key == "":
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        """Return an empty signature, whatever the message."""
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        """Always reject the signature."""
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        """JWK serialization is not supported for ``alg = "none"``."""
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        """JWK deserialization is not supported for ``alg = "none"``."""
        raise NotImplementedError()

327 

328 

class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        """
        :param hash_alg: hashlib-style constructor (e.g. ``hashlib.sha256``)
            used both as the HMAC digest and for key-length checks.
        """
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes for use as an HMAC secret.

        :raises InvalidKeyError: if the key looks like PEM data or an SSH
            key, to prevent asymmetric key material being used as a secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: ...

    @overload
    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: ...

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize a secret as an ``oct`` JWK with a base64url ``k`` value.

        :return: the JWK as a dict if ``as_dict`` is true, else as a JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Extract the raw secret from an ``oct`` JWK given as JSON text or a dict.

        :raises InvalidKeyError: if the input is not valid JSON, is not a JSON
            object, has a ``kty`` other than ``oct``, or lacks ``k``.
        """
        try:
            if isinstance(jwk, str):
                obj = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        # Valid JSON such as "[]" or "1" cannot be a JWK; without this check
        # the .get() below would raise AttributeError.
        if not isinstance(obj, dict):
            raise InvalidKeyError("Key is not a JSON object")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def check_key_length(self, key: bytes) -> str | None:
        """
        Return a warning if ``key`` is shorter than the hash's digest size,
        otherwise None.
        """
        # Instantiate the hash once; it supplies both the size and the name.
        hash_obj = self.hash_alg()
        min_length = hash_obj.digest_size
        if len(key) >= min_length:
            return None
        return (
            f"The HMAC key is {len(key)} bytes long, which is below "
            f"the minimum recommended length of {min_length} bytes for "
            f"{hash_obj.name.upper()}. "
            "See RFC 7518 Section 3.2."
        )

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` under ``key``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """Recompute the HMAC and compare it to ``sig`` in constant time."""
        return hmac.compare_digest(sig, self.sign(msg, key))

406 

407 

if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_RSA_KEY_TYPES
        # Keys with fewer bits than this make check_key_length() return a warning.
        _MIN_KEY_SIZE: ClassVar[int] = 2048

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            """
            :param hash_alg: ``cryptography`` hash class (e.g. ``hashes.SHA256``)
                used for signing and verification.
            """
            self.hash_alg = hash_alg

        def check_key_length(self, key: AllowedRSAKeys) -> str | None:
            """
            Return a warning if ``key`` is smaller than ``_MIN_KEY_SIZE`` bits,
            otherwise None.
            """
            if key.key_size < self._MIN_KEY_SIZE:
                return (
                    f"The RSA key is {key.key_size} bits long, which is below "
                    f"the minimum recommended size of {self._MIN_KEY_SIZE} bits. "
                    f"See NIST SP 800-131A."
                )
            return None

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return ``key`` as an RSA key object.

            Already-loaded RSA key objects are returned unchanged. For str/bytes
            input, data starting with ``ssh-rsa`` is loaded as an OpenSSH public
            key, anything else as an unencrypted PEM private key; if that raises
            ValueError, the data is retried as a PEM public key.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the fallback PEM public-key parse fails,
                or if the loaded key is not an RSA key.
            """
            if isinstance(key, self._crypto_key_types):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                else:
                    private_key: PrivateKeyTypes = load_pem_private_key(
                        key_bytes, password=None
                    )
                    self.check_crypto_key_type(private_key)
                    return cast(RSAPrivateKey, private_key)
            except ValueError:
                # Not loadable as above: fall back to a PEM public key.
                try:
                    public_key = load_pem_public_key(key_bytes)
                    self.check_crypto_key_type(public_key)
                    return cast(RSAPublicKey, public_key)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK.

            Objects exposing ``private_numbers`` are serialized with the private
            and CRT parameters and ``key_ops: ["sign"]``; objects exposing
            ``verify`` are serialized with ``n``/``e`` and ``key_ops: ["verify"]``.

            :return: the JWK as a dict if ``as_dict`` is true, else as a JSON string.
            :raises InvalidKeyError: if ``key_obj`` matches neither case.
            """
            obj: dict[str, Any] | None = None

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Build an RSA key from a JWK given as JSON text or a dict.

            A JWK with ``d``, ``e`` and ``n`` yields a private key. When the CRT
            parameters (``p``, ``q``, ``dp``, ``dq``, ``qi``) are all absent,
            they are recomputed from ``n``, ``e`` and ``d``. A JWK with only
            ``n`` and ``e`` yields a public key.

            :raises InvalidKeyError: on invalid JSON, a non-object JSON value,
                a ``kty`` other than ``RSA``, an ``oth`` (multi-prime) member,
                partially present CRT parameters, or missing ``n``/``e``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # Valid JSON such as "[]" cannot be a JWK; avoid AttributeError below.
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    ) from None

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d was given: recover the primes and derive the CRT values.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PKCS#1 v1.5 padding and the configured hash."""
            signature: bytes = key.sign(msg, padding.PKCS1v15(), self.hash_alg())
            return signature

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` is a valid PKCS#1 v1.5 signature of ``msg``."""
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        _crypto_key_types = ALLOWED_EC_KEY_TYPES

        def __init__(
            self,
            hash_alg: type[hashes.HashAlgorithm],
            expected_curve: type[EllipticCurve] | None = None,
        ) -> None:
            """
            :param hash_alg: ``cryptography`` hash class used with ECDSA.
            :param expected_curve: curve class keys must use; ``None`` disables
                the curve check.
            """
            self.hash_alg = hash_alg
            self.expected_curve = expected_curve

        def _validate_curve(self, key: AllowedECKeys) -> None:
            """Validate that the key's curve matches the expected curve."""
            if self.expected_curve is None:
                return

            if not isinstance(key.curve, self.expected_curve):
                raise InvalidKeyError(
                    f"The key's curve '{key.curve.name}' does not match the expected "
                    f"curve '{self.expected_curve.name}' for this algorithm"
                )

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Return ``key`` as an EC key object whose curve matches ``expected_curve``.

            str/bytes input is first loaded as a public key: OpenSSH when it
            starts with ``ecdsa-sha2-``, otherwise PEM. If that raises
            ValueError, it is loaded as an unencrypted PEM private key; errors
            from that second attempt propagate unchanged.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the key is not an EC key or uses the
                wrong curve.
            """
            if isinstance(key, self._crypto_key_types):
                self._validate_curve(key)
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
                else:
                    public_key = load_pem_public_key(key_bytes)

                # Explicit check the key to prevent confusing errors from cryptography
                self.check_crypto_key_type(public_key)
                ec_public_key = cast(EllipticCurvePublicKey, public_key)
                self._validate_curve(ec_public_key)
                return ec_public_key
            except ValueError:
                private_key = load_pem_private_key(key_bytes, password=None)
                self.check_crypto_key_type(private_key)
                ec_private_key = cast(EllipticCurvePrivateKey, private_key)
                self._validate_curve(ec_private_key)
                return ec_private_key

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            """Sign ``msg`` and convert the DER ECDSA signature to raw form."""
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Return True if ``sig`` (raw form, as produced by sign()) is valid
            for ``msg``. A private key is accepted; its public key is used.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key (P-256, P-384, P-521 or secp256k1) as a JWK.

            Coordinates, and ``d`` for private keys, are padded to the
            curve's key size.

            :return: the JWK as a dict if ``as_dict`` is true, else as a JSON string.
            :raises InvalidKeyError: for non-EC keys or unsupported curves.
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Build an EC key from a JWK given as JSON text or a dict.

            Returns a private key when ``d`` is present, otherwise a public key.

            :raises InvalidKeyError: on invalid JSON, a non-object JSON value,
                a ``kty`` other than ``EC``, missing ``x``/``y``, an unsupported
                ``crv``, or coordinate/``d`` lengths that do not fit the curve.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # Valid JSON such as "[]" cannot be a JWK; avoid AttributeError below.
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key") from None

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key") from None

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            curve_obj: EllipticCurve

            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve P-256"
                    ) from None
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = SECP384R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 48 bytes for curve P-384"
                    ) from None
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = SECP521R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 66 bytes for curve P-521"
                    ) from None
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                # The message must be formatted here; passing the values as
                # extra exception args left "{}" placeholders in the text.
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            return EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1

        MGF1 uses the configured hash, and the salt length equals that
        hash's digest size. Key handling is inherited from RSAAlgorithm.
        """

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PSS padding (MGF1, salt = digest size)."""
            signature: bytes = key.sign(
                msg,
                padding.PSS(
                    mgf=padding.MGF1(self.hash_alg()),
                    salt_length=self.hash_alg().digest_size,
                ),
                self.hash_alg(),
            )
            return signature

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` is a valid PSS signature of ``msg``."""
            try:
                key.verify(
                    sig,
                    msg,
                    padding.PSS(
                        mgf=padding.MGF1(self.hash_alg()),
                        salt_length=self.hash_alg().digest_size,
                    ),
                    self.hash_alg(),
                )
                return True
            except InvalidSignature:
                return False

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA

        This class requires ``cryptography>=2.6`` to be installed.
        """

        _crypto_key_types = ALLOWED_OKP_KEY_TYPES

        def __init__(self, **kwargs: Any) -> None:
            # Keyword arguments are accepted and ignored; no configuration is stored.
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Return ``key`` as an Ed25519/Ed448 key object.

            Non-str/bytes input is type-checked and returned as-is. Text input is
            dispatched on its content: ``-----BEGIN PUBLIC`` (PEM public key),
            ``-----BEGIN PRIVATE`` (unencrypted PEM private key), or an ``ssh-``
            prefix (OpenSSH public key).

            :raises InvalidKeyError: if the text matches none of these forms or
                the loaded key is not an EdDSA key.
            """
            if not isinstance(key, (str, bytes)):
                self.check_crypto_key_type(key)
                return key

            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            key_bytes = key.encode("utf-8") if isinstance(key, str) else key

            loaded_key: PublicKeyTypes | PrivateKeyTypes
            if "-----BEGIN PUBLIC" in key_str:
                loaded_key = load_pem_public_key(key_bytes)
            elif "-----BEGIN PRIVATE" in key_str:
                loaded_key = load_pem_private_key(key_bytes, password=None)
            elif key_str[0:4] == "ssh-":
                loaded_key = load_ssh_public_key(key_bytes)
            else:
                raise InvalidKeyError("Not a public or private key")

            # Explicit check the key to prevent confusing errors from cryptography
            self.check_crypto_key_type(loaded_key)
            return cast("AllowedOKPKeys", loaded_key)

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``
            :param str|bytes msg: Message to sign (str is UTF-8 encoded)
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            signature: bytes = key.sign(msg_bytes)
            return signature

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

            :param str|bytes sig: EdDSA signature to check ``msg`` against
            :param str|bytes msg: Message to verify (str is UTF-8 encoded)
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

                public_key = (
                    key.public_key()
                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                    else key
                )
                public_key.verify(sig_bytes, msg_bytes)
                return True  # If no exception was raised, the signature is valid.
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: ...

        @overload
        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: ...

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``OKP`` JWK.

            Public keys produce ``x``; private keys produce ``x`` and ``d``.

            :return: the JWK as a dict if ``as_dict`` is true, else as a JSON string.
            :raises InvalidKeyError: if ``key`` is not an EdDSA key.
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "d": base64url_encode(force_bytes(d)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Build an Ed25519/Ed448 key from an ``OKP`` JWK given as JSON text or a dict.

            Returns a private key (built from ``d`` alone) when ``d`` is present,
            otherwise a public key built from ``x``.

            :raises InvalidKeyError: on invalid JSON, a non-object JSON value,
                a ``kty`` other than ``OKP``, an unsupported ``crv``, a missing
                ``x``, or key bytes rejected by ``cryptography``.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            # Valid JSON such as "[]" cannot be a JWK; avoid AttributeError below.
            if not isinstance(obj, dict):
                raise InvalidKeyError("Key is not a JSON object")

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519" and curve != "Ed448":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    if curve == "Ed25519":
                        return Ed25519PublicKey.from_public_bytes(x)
                    return Ed448PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                if curve == "Ed25519":
                    return Ed25519PrivateKey.from_private_bytes(d)
                return Ed448PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err