Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 67%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

189 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from collections.abc import Sequence 

7from typing import TYPE_CHECKING, Any 

8 

9from .algorithms import ( 

10 Algorithm, 

11 get_default_algorithms, 

12 has_crypto, 

13 requires_cryptography, 

14) 

15from .api_jwk import PyJWK 

16from .exceptions import ( 

17 DecodeError, 

18 InvalidAlgorithmError, 

19 InvalidKeyError, 

20 InvalidSignatureError, 

21 InvalidTokenError, 

22) 

23from .utils import base64url_decode, base64url_encode 

24from .warnings import InsecureKeyLengthWarning, RemovedInPyjwt3Warning 

25 

26if TYPE_CHECKING: 

27 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

28 from .types import SigOptions 

29 

30 

31class PyJWS: 

32 header_typ = "JWT" 

33 

34 def __init__( 

35 self, 

36 algorithms: Sequence[str] | None = None, 

37 options: SigOptions | None = None, 

38 ) -> None: 

39 self._algorithms = get_default_algorithms() 

40 self._valid_algs = ( 

41 set(algorithms) if algorithms is not None else set(self._algorithms) 

42 ) 

43 

44 # Remove algorithms that aren't on the whitelist 

45 for key in list(self._algorithms.keys()): 

46 if key not in self._valid_algs: 

47 del self._algorithms[key] 

48 

49 self.options: SigOptions = self._get_default_options() 

50 if options is not None: 

51 self.options = {**self.options, **options} 

52 

53 @staticmethod 

54 def _get_default_options() -> SigOptions: 

55 return {"verify_signature": True, "enforce_minimum_key_length": False} 

56 

57 def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None: 

58 """ 

59 Registers a new Algorithm for use when creating and verifying tokens. 

60 

61 :param str alg_id: the ID of the Algorithm 

62 :param alg_obj: the Algorithm object 

63 :type alg_obj: Algorithm 

64 """ 

65 if alg_id in self._algorithms: 

66 raise ValueError("Algorithm already has a handler.") 

67 

68 if not isinstance(alg_obj, Algorithm): 

69 raise TypeError("Object is not of type `Algorithm`") 

70 

71 self._algorithms[alg_id] = alg_obj 

72 self._valid_algs.add(alg_id) 

73 

74 def unregister_algorithm(self, alg_id: str) -> None: 

75 """ 

76 Unregisters an Algorithm for use when creating and verifying tokens 

77 :param str alg_id: the ID of the Algorithm 

78 :raises KeyError: if algorithm is not registered. 

79 """ 

80 if alg_id not in self._algorithms: 

81 raise KeyError( 

82 "The specified algorithm could not be removed" 

83 " because it is not registered." 

84 ) 

85 

86 del self._algorithms[alg_id] 

87 self._valid_algs.remove(alg_id) 

88 

89 def get_algorithms(self) -> list[str]: 

90 """ 

91 Returns a list of supported values for the `alg` parameter. 

92 

93 :rtype: list[str] 

94 """ 

95 return list(self._valid_algs) 

96 

97 def get_algorithm_by_name(self, alg_name: str) -> Algorithm: 

98 """ 

99 For a given string name, return the matching Algorithm object. 

100 

101 Example usage: 

102 >>> jws_obj = PyJWS() 

103 >>> jws_obj.get_algorithm_by_name("RS256") 

104 

105 :param alg_name: The name of the algorithm to retrieve 

106 :type alg_name: str 

107 :rtype: Algorithm 

108 """ 

109 try: 

110 return self._algorithms[alg_name] 

111 except KeyError as e: 

112 if not has_crypto and alg_name in requires_cryptography: 

113 raise NotImplementedError( 

114 f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?" 

115 ) from e 

116 raise NotImplementedError("Algorithm not supported") from e 

117 

118 def encode( 

119 self, 

120 payload: bytes, 

121 key: AllowedPrivateKeys | PyJWK | str | bytes, 

122 algorithm: str | None = "HS256", 

123 headers: dict[str, Any] | None = None, 

124 json_encoder: type[json.JSONEncoder] | None = None, 

125 is_payload_detached: bool = False, 

126 sort_headers: bool = True, 

127 ) -> str: 

128 segments: list[bytes] = [] 

129 

130 # declare a new var to narrow the type for type checkers 

131 if algorithm is None: 

132 if isinstance(key, PyJWK): 

133 algorithm_ = key.algorithm_name 

134 else: 

135 algorithm_ = "none" 

136 else: 

137 algorithm_ = algorithm 

138 

139 # Prefer headers values if present to function parameters. 

140 if headers: 

141 headers_alg = headers.get("alg") 

142 if headers_alg: 

143 algorithm_ = headers["alg"] 

144 

145 headers_b64 = headers.get("b64") 

146 if headers_b64 is False: 

147 is_payload_detached = True 

148 

149 # Header 

150 header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_} 

151 

152 if headers: 

153 self._validate_headers(headers) 

154 header.update(headers) 

155 

156 if not header["typ"]: 

157 del header["typ"] 

158 

159 if is_payload_detached: 

160 header["b64"] = False 

161 elif "b64" in header: 

162 # True is the standard value for b64, so no need for it 

163 del header["b64"] 

164 

165 json_header = json.dumps( 

166 header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers 

167 ).encode() 

168 

169 segments.append(base64url_encode(json_header)) 

170 

171 if is_payload_detached: 

172 msg_payload = payload 

173 else: 

174 msg_payload = base64url_encode(payload) 

175 segments.append(msg_payload) 

176 

177 # Segments 

178 signing_input = b".".join(segments) 

179 

180 alg_obj = self.get_algorithm_by_name(algorithm_) 

181 if isinstance(key, PyJWK): 

182 key = key.key 

183 key = alg_obj.prepare_key(key) 

184 

185 key_length_msg = alg_obj.check_key_length(key) 

186 if key_length_msg: 

187 if self.options.get("enforce_minimum_key_length", False): 

188 raise InvalidKeyError(key_length_msg) 

189 else: 

190 warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=2) 

191 

192 signature = alg_obj.sign(signing_input, key) 

193 

194 segments.append(base64url_encode(signature)) 

195 

196 # Don't put the payload content inside the encoded token when detached 

197 if is_payload_detached: 

198 segments[1] = b"" 

199 encoded_string = b".".join(segments) 

200 

201 return encoded_string.decode("utf-8") 

202 

203 def decode_complete( 

204 self, 

205 jwt: str | bytes, 

206 key: AllowedPublicKeys | PyJWK | str | bytes = "", 

207 algorithms: Sequence[str] | None = None, 

208 options: SigOptions | None = None, 

209 detached_payload: bytes | None = None, 

210 **kwargs: dict[str, Any], 

211 ) -> dict[str, Any]: 

212 if kwargs: 

213 warnings.warn( 

214 "passing additional kwargs to decode_complete() is deprecated " 

215 "and will be removed in pyjwt version 3. " 

216 f"Unsupported kwargs: {tuple(kwargs.keys())}", 

217 RemovedInPyjwt3Warning, 

218 stacklevel=2, 

219 ) 

220 merged_options: SigOptions 

221 if options is None: 

222 merged_options = self.options 

223 else: 

224 merged_options = {**self.options, **options} 

225 

226 verify_signature = merged_options["verify_signature"] 

227 

228 if verify_signature and not algorithms and not isinstance(key, PyJWK): 

229 raise DecodeError( 

230 'It is required that you pass in a value for the "algorithms" argument when calling decode().' 

231 ) 

232 

233 payload, signing_input, header, signature = self._load(jwt) 

234 

235 if header.get("b64", True) is False: 

236 if detached_payload is None: 

237 raise DecodeError( 

238 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' 

239 ) 

240 payload = detached_payload 

241 signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload]) 

242 

243 if verify_signature: 

244 self._verify_signature(signing_input, header, signature, key, algorithms) 

245 

246 return { 

247 "payload": payload, 

248 "header": header, 

249 "signature": signature, 

250 } 

251 

252 def decode( 

253 self, 

254 jwt: str | bytes, 

255 key: AllowedPublicKeys | PyJWK | str | bytes = "", 

256 algorithms: Sequence[str] | None = None, 

257 options: SigOptions | None = None, 

258 detached_payload: bytes | None = None, 

259 **kwargs: dict[str, Any], 

260 ) -> Any: 

261 if kwargs: 

262 warnings.warn( 

263 "passing additional kwargs to decode() is deprecated " 

264 "and will be removed in pyjwt version 3. " 

265 f"Unsupported kwargs: {tuple(kwargs.keys())}", 

266 RemovedInPyjwt3Warning, 

267 stacklevel=2, 

268 ) 

269 decoded = self.decode_complete( 

270 jwt, key, algorithms, options, detached_payload=detached_payload 

271 ) 

272 return decoded["payload"] 

273 

274 def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]: 

275 """Returns back the JWT header parameters as a `dict` 

276 

277 Note: The signature is not verified so the header parameters 

278 should not be fully trusted until signature verification is complete 

279 """ 

280 headers = self._load(jwt)[2] 

281 self._validate_headers(headers) 

282 

283 return headers 

284 

285 def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]: 

286 if isinstance(jwt, str): 

287 jwt = jwt.encode("utf-8") 

288 

289 if not isinstance(jwt, bytes): 

290 raise DecodeError(f"Invalid token type. Token must be a {bytes}") 

291 

292 try: 

293 signing_input, crypto_segment = jwt.rsplit(b".", 1) 

294 header_segment, payload_segment = signing_input.split(b".", 1) 

295 except ValueError as err: 

296 raise DecodeError("Not enough segments") from err 

297 

298 try: 

299 header_data = base64url_decode(header_segment) 

300 except (TypeError, binascii.Error) as err: 

301 raise DecodeError("Invalid header padding") from err 

302 

303 try: 

304 header: dict[str, Any] = json.loads(header_data) 

305 except ValueError as e: 

306 raise DecodeError(f"Invalid header string: {e}") from e 

307 

308 if not isinstance(header, dict): 

309 raise DecodeError("Invalid header string: must be a json object") 

310 

311 try: 

312 payload = base64url_decode(payload_segment) 

313 except (TypeError, binascii.Error) as err: 

314 raise DecodeError("Invalid payload padding") from err 

315 

316 try: 

317 signature = base64url_decode(crypto_segment) 

318 except (TypeError, binascii.Error) as err: 

319 raise DecodeError("Invalid crypto padding") from err 

320 

321 return (payload, signing_input, header, signature) 

322 

323 def _verify_signature( 

324 self, 

325 signing_input: bytes, 

326 header: dict[str, Any], 

327 signature: bytes, 

328 key: AllowedPublicKeys | PyJWK | str | bytes = "", 

329 algorithms: Sequence[str] | None = None, 

330 ) -> None: 

331 if algorithms is None and isinstance(key, PyJWK): 

332 algorithms = [key.algorithm_name] 

333 try: 

334 alg = header["alg"] 

335 except KeyError: 

336 raise InvalidAlgorithmError("Algorithm not specified") from None 

337 

338 if not alg or (algorithms is not None and alg not in algorithms): 

339 raise InvalidAlgorithmError("The specified alg value is not allowed") 

340 

341 if isinstance(key, PyJWK): 

342 alg_obj = key.Algorithm 

343 prepared_key = key.key 

344 else: 

345 try: 

346 alg_obj = self.get_algorithm_by_name(alg) 

347 except NotImplementedError as e: 

348 raise InvalidAlgorithmError("Algorithm not supported") from e 

349 prepared_key = alg_obj.prepare_key(key) 

350 

351 key_length_msg = alg_obj.check_key_length(prepared_key) 

352 if key_length_msg: 

353 if self.options.get("enforce_minimum_key_length", False): 

354 raise InvalidKeyError(key_length_msg) 

355 else: 

356 warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=4) 

357 

358 if not alg_obj.verify(signing_input, prepared_key, signature): 

359 raise InvalidSignatureError("Signature verification failed") 

360 

361 def _validate_headers(self, headers: dict[str, Any]) -> None: 

362 if "kid" in headers: 

363 self._validate_kid(headers["kid"]) 

364 

365 def _validate_kid(self, kid: Any) -> None: 

366 if not isinstance(kid, str): 

367 raise InvalidTokenError("Key ID header parameter must be a string") 

368 

369 

370_jws_global_obj = PyJWS() 

371encode = _jws_global_obj.encode 

372decode_complete = _jws_global_obj.decode_complete 

373decode = _jws_global_obj.decode 

374register_algorithm = _jws_global_obj.register_algorithm 

375unregister_algorithm = _jws_global_obj.unregister_algorithm 

376get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name 

377get_unverified_header = _jws_global_obj.get_unverified_header