Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

209 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from collections.abc import Sequence 

7from typing import TYPE_CHECKING, Any 

8 

9from .algorithms import ( 

10 Algorithm, 

11 get_default_algorithms, 

12 has_crypto, 

13 requires_cryptography, 

14) 

15from .api_jwk import PyJWK 

16from .exceptions import ( 

17 DecodeError, 

18 InvalidAlgorithmError, 

19 InvalidKeyError, 

20 InvalidSignatureError, 

21 InvalidTokenError, 

22) 

23from .utils import base64url_decode, base64url_encode 

24from .warnings import InsecureKeyLengthWarning, RemovedInPyjwt3Warning 

25 

26if TYPE_CHECKING: 

27 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

28 from .types import SigOptions 

29 

# Sentinel used as the default of ``PyJWS.encode(algorithm=...)`` so that an
# omitted argument can be told apart from an explicit ``algorithm=None``.
_ALGORITHM_UNSET = object()


class PyJWS:
    """
    Signs byte payloads into dot-separated ``header.payload.signature``
    tokens and loads/verifies such tokens.

    Each instance owns a registry mapping algorithm names to ``Algorithm``
    objects and a ``SigOptions`` dict that acts as the default for every call.
    """

    # Default value written to the "typ" header by encode().
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
    ) -> None:
        """
        Build the algorithm registry and the instance-wide options.

        :param algorithms: names of the algorithms to allow; ``None`` allows
            everything returned by ``get_default_algorithms()``
        :param options: overrides merged on top of ``_get_default_options()``
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Remove algorithms that aren't on the whitelist.  Iterate over a
        # snapshot of the keys because the dict is mutated inside the loop.
        for key in list(self._algorithms):
            if key not in self._valid_algs:
                del self._algorithms[key]

        self.options: SigOptions = self._get_default_options()
        if options is not None:
            self.options = {**self.options, **options}

    @staticmethod
    def _get_default_options() -> SigOptions:
        """Return a fresh dict holding the default signature options."""
        return {"verify_signature": True, "enforce_minimum_key_length": False}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :param str alg_id: the ID of the Algorithm
        :param alg_obj: the Algorithm object
        :type alg_obj: Algorithm
        :raises ValueError: if ``alg_id`` already has a handler
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        :param str alg_id: the ID of the Algorithm
        :raises KeyError: if algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the `alg` parameter.

        The names are kept in a set internally; they are returned sorted so
        the result does not depend on set iteration order.

        :rtype: list[str]
        """
        return sorted(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:
        >>> jws_obj = PyJWS()
        >>> jws_obj.get_algorithm_by_name("RS256")

        :param alg_name: The name of the algorithm to retrieve
        :type alg_name: str
        :rtype: Algorithm
        :raises NotImplementedError: if no handler is registered under
            ``alg_name``
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            # Give a more helpful hint when the algorithm is one of those
            # listed in ``requires_cryptography`` and ``has_crypto`` is false.
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = _ALGORITHM_UNSET,  # type: ignore[assignment]
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """
        Sign ``payload`` and return the resulting token as a string.

        :param payload: the bytes to sign
        :param key: the signing key; for a ``PyJWK`` its ``.key`` is used
        :param algorithm: name of the signing algorithm.  When omitted it
            falls back to the ``PyJWK``'s ``algorithm_name`` or ``"HS256"``;
            when ``None`` it falls back to the ``PyJWK``'s ``algorithm_name``
            or ``"none"``
        :param headers: extra header fields.  A truthy ``"alg"`` entry takes
            precedence over ``algorithm`` and ``"b64": False`` turns on
            detached-payload mode
        :param json_encoder: passed as ``cls`` to ``json.dumps`` for the header
        :param is_payload_detached: sign the raw (not base64url-encoded)
            payload, leave the payload segment of the token empty and set
            ``"b64": False`` in the header
        :param sort_headers: passed as ``sort_keys`` to ``json.dumps``
        :raises InvalidKeyError: if the algorithm reports a key-length problem
            and the ``enforce_minimum_key_length`` option is set
        :raises NotImplementedError: if the algorithm is not registered
        """
        segments: list[bytes] = []

        # declare a new var to narrow the type for type checkers
        if algorithm is _ALGORITHM_UNSET:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "HS256"
        elif algorithm is None:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "none"
        else:
            algorithm_ = algorithm

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers["alg"]

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers, encoding=True)
            header.update(headers)

        # A falsy "typ" (e.g. overridden with None or "") is dropped entirely.
        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # A detached payload is signed as-is instead of base64url-encoded.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)

        # A non-empty message means the key is considered too short: either
        # fail hard or just warn, depending on the instance options.
        key_length_msg = alg_obj.check_key_length(key)
        if key_length_msg:
            if self.options.get("enforce_minimum_key_length", False):
                raise InvalidKeyError(key_length_msg)
            else:
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=2)

        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Load a token, optionally verify its signature, and return its parts.

        :param jwt: the token to decode
        :param key: the verification key
        :param algorithms: allowed ``alg`` values; required when the signature
            is verified and ``key`` is not a ``PyJWK``
        :param options: per-call overrides merged on top of ``self.options``
        :param detached_payload: the payload to verify against when the
            token's header carries ``"b64": false``
        :param kwargs: deprecated; any value triggers a
            ``RemovedInPyjwt3Warning``
        :returns: a dict with the keys ``"payload"``, ``"header"`` and
            ``"signature"``
        :raises DecodeError: if the token is malformed or a required argument
            is missing
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        merged_options: SigOptions
        if options is None:
            merged_options = self.options
        else:
            merged_options = {**self.options, **options}

        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        self._validate_headers(header)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input from the header segment plus the raw
            # detached payload supplied by the caller.
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: dict[str, Any],
    ) -> Any:
        """
        Same as :meth:`decode_complete` but return only the payload.

        Extra ``kwargs`` are deprecated: they trigger a
        ``RemovedInPyjwt3Warning`` here and are not forwarded.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a `dict`

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """
        Split a token into its segments and decode each of them.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``signing_input`` is everything before the last ``.`` of the token
        :raises DecodeError: if the token is not ``str``/``bytes``, has too
            few segments, or a segment cannot be decoded
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            # Unpacking fails with ValueError when a separator is missing.
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header: dict[str, Any] = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
    ) -> None:
        """
        Check ``signature`` over ``signing_input`` using the header's ``alg``.

        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not a
            string, not in ``algorithms``, or has no registered handler
        :raises InvalidKeyError: if the key is reported as too short and the
            ``enforce_minimum_key_length`` option is set
        :raises InvalidSignatureError: if verification fails
        """
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        # The header is untrusted input.  A non-string ``alg`` (e.g. a JSON
        # list or object) is rejected up front so it can never reach a
        # hash-based membership test or registry lookup and escape as
        # ``TypeError`` instead of ``InvalidAlgorithmError``.
        if (
            not alg
            or not isinstance(alg, str)
            or (algorithms is not None and alg not in algorithms)
        ):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        # NOTE(review): this reads ``self.options``; the per-call options
        # merged in decode_complete() are not passed in, so a per-call
        # ``enforce_minimum_key_length`` override is not consulted here.
        key_length_msg = alg_obj.check_key_length(prepared_key)
        if key_length_msg:
            if self.options.get("enforce_minimum_key_length", False):
                raise InvalidKeyError(key_length_msg)
            else:
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=4)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    # Extensions that PyJWT actually understands and supports
    _supported_crit: set[str] = {"b64"}

    def _validate_headers(
        self, headers: dict[str, Any], *, encoding: bool = False
    ) -> None:
        """
        Validate the ``kid`` header and, unless ``encoding`` is true, the
        ``crit`` header.

        :raises InvalidTokenError: if either header is invalid
        """
        if "kid" in headers:
            self._validate_kid(headers["kid"])
        if not encoding and "crit" in headers:
            self._validate_crit(headers)

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a string."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

    def _validate_crit(self, headers: dict[str, Any]) -> None:
        """
        Validate the ``crit`` header: it must be a non-empty list of strings,
        each one listed in ``_supported_crit`` and present as a header itself.

        :raises InvalidTokenError: if any of these conditions is violated
        """
        crit = headers["crit"]
        if not isinstance(crit, list) or len(crit) == 0:
            raise InvalidTokenError("Invalid 'crit' header: must be a non-empty list")
        for ext in crit:
            if not isinstance(ext, str):
                raise InvalidTokenError("Invalid 'crit' header: values must be strings")
            if ext not in self._supported_crit:
                raise InvalidTokenError(f"Unsupported critical extension: {ext}")
            if ext not in headers:
                raise InvalidTokenError(
                    f"Critical extension '{ext}' is missing from headers"
                )

398 

399 

# Module-level PyJWS instance built with the default algorithms and options.
# Its bound methods are exposed below as plain module-level functions, so
# every alias shares this one instance's algorithm registry and options.
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header