Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

224 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from collections.abc import Sequence 

7from typing import TYPE_CHECKING, Any 

8 

9from .algorithms import ( 

10 Algorithm, 

11 get_default_algorithms, 

12 has_crypto, 

13 requires_cryptography, 

14) 

15from .api_jwk import PyJWK 

16from .exceptions import ( 

17 DecodeError, 

18 InvalidAlgorithmError, 

19 InvalidKeyError, 

20 InvalidSignatureError, 

21 InvalidTokenError, 

22) 

23from .utils import base64url_decode, base64url_encode 

24from .warnings import InsecureKeyLengthWarning, RemovedInPyjwt3Warning 

25 

26if TYPE_CHECKING: 

27 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

28 from .types import SigOptions 

29 

# Sentinel used as the default of ``PyJWS.encode(algorithm=...)`` so that
# ``encode`` can tell "argument omitted" (falls back to "HS256") apart from an
# explicit ``None`` (falls back to "none") when the key is not a PyJWK.
_ALGORITHM_UNSET = object()

31 

32 

class PyJWS:
    """Sign and verify compact-serialised tokens (``header.payload.signature``).

    Each instance keeps its own algorithm registry (seeded from
    ``get_default_algorithms()`` and optionally restricted by an allow-list)
    and its own default signature options.
    """

    # Default value for the "typ" header. ``encode`` removes the header
    # entirely when the final value is falsy.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
    ) -> None:
        """
        Build the algorithm registry and the default options.

        :param algorithms: allow-list of algorithm names; ``None`` keeps every
            default algorithm
        :param options: overrides merged on top of the default options
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Remove algorithms that aren't on the whitelist
        for key in list(self._algorithms.keys()):
            if key not in self._valid_algs:
                del self._algorithms[key]

        self.options: SigOptions = self._get_default_options()
        if options is not None:
            self.options = {**self.options, **options}

    @staticmethod
    def _get_default_options() -> SigOptions:
        """Return a fresh dict of the default signature options."""
        return {"verify_signature": True, "enforce_minimum_key_length": False}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :param str alg_id: the ID of the Algorithm
        :param alg_obj: the Algorithm object
        :type alg_obj: Algorithm
        :raises ValueError: if ``alg_id`` is already registered
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens

        :param str alg_id: the ID of the Algorithm
        :raises KeyError: if algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the `alg` parameter.

        :rtype: list[str]
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:
        >>> jws_obj = PyJWS()
        >>> jws_obj.get_algorithm_by_name("RS256")

        :param alg_name: The name of the algorithm to retrieve
        :type alg_name: str
        :rtype: Algorithm
        :raises NotImplementedError: if no algorithm is registered under
            ``alg_name``
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            # Give a more helpful hint when the algorithm is only missing
            # because the optional cryptography backend is not available.
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = _ALGORITHM_UNSET,  # type: ignore[assignment]
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """
        Sign ``payload`` and return the compact serialisation as ``str``.

        :param payload: raw payload bytes
        :param key: signing key; a ``PyJWK`` also supplies the default
            algorithm when ``algorithm`` is omitted or ``None``
        :param algorithm: algorithm name. Omitted means "HS256" and ``None``
            means "none" (for non-PyJWK keys); a truthy ``headers["alg"]``
            takes precedence over this argument
        :param headers: extra header parameters merged over ``typ``/``alg``
        :param json_encoder: ``json.JSONEncoder`` subclass used for the header
        :param is_payload_detached: sign the unencoded payload and leave the
            payload segment of the token empty; also forced by
            ``headers["b64"] is False``
        :param sort_headers: forwarded to ``json.dumps(sort_keys=...)``
        :raises InvalidTokenError: if a detached payload is requested and the
            supplied ``crit`` header is not a list
        :raises InvalidKeyError: if the key is reported as too short and
            ``enforce_minimum_key_length`` is enabled
        :raises NotImplementedError: if the algorithm is not registered
        """
        segments: list[bytes] = []

        # declare a new var to narrow the type for type checkers
        if algorithm is _ALGORITHM_UNSET:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "HS256"
        elif algorithm is None:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "none"
        else:
            algorithm_ = algorithm

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers["alg"]

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers, encoding=True)
            header.update(headers)

        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
            # RFC 7797 §3: producers MUST list "b64" in "crit" whenever
            # "b64" appears in the protected header, so b64-unaware
            # verifiers don't silently treat an unencoded payload as
            # base64-encoded.
            existing_crit = header.get("crit", [])
            if not isinstance(existing_crit, list):
                raise InvalidTokenError("Invalid 'crit' header: must be a list")
            if "b64" not in existing_crit:
                header["crit"] = [*existing_crit, "b64"]
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)

        key_length_msg = alg_obj.check_key_length(key)
        if key_length_msg:
            if self.options.get("enforce_minimum_key_length", False):
                raise InvalidKeyError(key_length_msg)
            else:
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=2)

        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Parse a token, optionally verify it, and return all of its parts.

        :param jwt: the compact-serialised token
        :param key: verification key
        :param algorithms: allow-list of acceptable ``alg`` values; required
            when the signature is verified and ``key`` is not a ``PyJWK``
        :param options: per-call overrides merged over ``self.options``
        :param detached_payload: the payload to verify when the token header
            sets ``b64`` to ``False``
        :param kwargs: deprecated; any value triggers
            ``RemovedInPyjwt3Warning`` and is otherwise ignored
        :returns: dict with the keys ``"payload"``, ``"header"`` and
            ``"signature"``
        :raises DecodeError: if the token is malformed, ``algorithms`` is
            missing, or a required ``detached_payload`` is missing
        :raises InvalidTokenError: if header validation fails
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        merged_options: SigOptions
        if options is None:
            merged_options = self.options
        else:
            merged_options = {**self.options, **options}

        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        self._validate_headers(header)

        if header.get("b64", True) is False:
            # RFC 7797 §3: when "b64" is present in the protected header,
            # it MUST also appear in "crit". A token that sets b64=false
            # without declaring it critical is malformed.
            crit = header.get("crit") or []
            if not isinstance(crit, list) or "b64" not in crit:
                raise InvalidTokenError(
                    "The 'b64' header parameter requires 'b64' to be listed in 'crit'."
                )
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input as "<header segment>.<raw payload>".
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(
                signing_input,
                header,
                signature,
                key,
                algorithms,
                options=merged_options,
            )

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> Any:
        """
        Like :meth:`decode_complete`, but return only the payload.

        Extra ``kwargs`` are deprecated: they trigger
        ``RemovedInPyjwt3Warning`` and are not forwarded.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a `dict`

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """
        Split a compact token and decode its segments without verifying it.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``signing_input`` is everything before the last ``"."``
        :raises DecodeError: if the token is not str/bytes, has too few
            segments, or a segment cannot be decoded
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header: dict[str, Any] = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        if header.get("b64", True) is False:
            # Detached payload form (RFC 7515 Appendix F): the compact-form
            # payload segment must be empty; the caller supplies the actual
            # payload via the `detached_payload` argument in decode_complete.
            # Skipping the base64 decode here removes an unauthenticated work
            # amplifier — otherwise an attacker can inflate the unused
            # segment to force CPU + memory cost before the signature is
            # even checked.
            if payload_segment:
                raise DecodeError("Payload segment must be empty when 'b64' is false.")
            payload = b""
        else:
            try:
                payload = base64url_decode(payload_segment)
            except (TypeError, binascii.Error) as err:
                raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
    ) -> None:
        """
        Check ``signature`` over ``signing_input`` using the header's ``alg``.

        :param options: options to honour; ``None`` falls back to
            ``self.options``
        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not a
            string, not in ``algorithms``, unsupported, or different from a
            ``PyJWK`` key's algorithm
        :raises InvalidKeyError: if the key is reported as too short and
            ``enforce_minimum_key_length`` is enabled
        :raises InvalidSignatureError: if verification fails
        """
        effective_options = options if options is not None else self.options

        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        # "alg" comes from the untrusted token header. Reject anything that is
        # not a non-empty string up front: an unhashable value such as a list
        # or dict would otherwise raise a bare TypeError from the membership
        # test below (when `algorithms` is a set) or from the registry lookup,
        # escaping callers that only handle PyJWT's own exceptions.
        # NOTE(review): if a caller passes a bare str as `algorithms`, the
        # `in` test below is a substring match — confirm callers pass a
        # list/tuple/set.
        if (
            not isinstance(alg, str)
            or not alg
            or (algorithms is not None and alg not in algorithms)
        ):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            # The PyJWK has a fixed algorithm bound at construction time.
            # Verification must use that algorithm, not whatever the token
            # header advertises, otherwise the caller's allow-list check
            # above degenerates into a string compare with no behavioural
            # effect on which algorithm actually verifies the signature.
            if alg != key.algorithm_name:
                raise InvalidAlgorithmError(
                    f"Token algorithm {alg!r} does not match the key's "
                    f"algorithm {key.algorithm_name!r}"
                )
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        key_length_msg = alg_obj.check_key_length(prepared_key)
        if key_length_msg:
            if effective_options.get("enforce_minimum_key_length", False):
                raise InvalidKeyError(key_length_msg)
            else:
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=4)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    # Extensions that PyJWT actually understands and supports
    _supported_crit: set[str] = {"b64"}

    def _validate_headers(
        self, headers: dict[str, Any], *, encoding: bool = False
    ) -> None:
        """
        Validate the ``kid`` header and, unless ``encoding`` is true, ``crit``.

        :raises InvalidTokenError: if either header is invalid
        """
        if "kid" in headers:
            self._validate_kid(headers["kid"])
        if not encoding and "crit" in headers:
            self._validate_crit(headers)

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a string."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

    def _validate_crit(self, headers: dict[str, Any]) -> None:
        """
        Validate the ``crit`` header of ``headers``.

        It must be a non-empty list of strings, each naming an extension in
        ``_supported_crit`` that is also present as a header.

        :raises InvalidTokenError: if any of those conditions is not met
        """
        crit = headers["crit"]
        if not isinstance(crit, list) or len(crit) == 0:
            raise InvalidTokenError("Invalid 'crit' header: must be a non-empty list")
        for ext in crit:
            if not isinstance(ext, str):
                raise InvalidTokenError("Invalid 'crit' header: values must be strings")
            if ext not in self._supported_crit:
                raise InvalidTokenError(f"Unsupported critical extension: {ext}")
            if ext not in headers:
                raise InvalidTokenError(
                    f"Critical extension '{ext}' is missing from headers"
                )

447 

448 

# Module-level convenience API: a single shared PyJWS instance, built with
# the default algorithms and options, whose bound methods are exposed as
# plain functions. Calls such as register_algorithm() made through these
# aliases therefore affect that one shared instance.
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header

456get_unverified_header = _jws_global_obj.get_unverified_header