Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 65%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

179 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from collections.abc import Sequence 

7from typing import TYPE_CHECKING, Any 

8 

9from .algorithms import ( 

10 Algorithm, 

11 get_default_algorithms, 

12 has_crypto, 

13 requires_cryptography, 

14) 

15from .api_jwk import PyJWK 

16from .exceptions import ( 

17 DecodeError, 

18 InvalidAlgorithmError, 

19 InvalidSignatureError, 

20 InvalidTokenError, 

21) 

22from .utils import base64url_decode, base64url_encode 

23from .warnings import RemovedInPyjwt3Warning 

24 

25if TYPE_CHECKING: 

26 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

27 from .types import SigOptions 

28 

29 

class PyJWS:
    """Encode and decode JSON Web Signatures (JWS) in compact serialization.

    Each instance owns its own algorithm registry, seeded from
    ``get_default_algorithms()`` and optionally narrowed by an allow-list,
    plus a dict of default signature options.
    """

    # Default value written to the ``typ`` header by :meth:`encode`.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
    ) -> None:
        """Create a JWS handler.

        :param algorithms: allow-list of algorithm names. When ``None``, every
            algorithm from ``get_default_algorithms()`` is allowed. Default
            algorithms that are not on the list are dropped from the registry.
        :param options: overrides merged on top of the default options
            (``{"verify_signature": True}``).
        """
        self._valid_algs = (
            set(algorithms)
            if algorithms is not None
            else set(get_default_algorithms())
        )

        # Keep only the default algorithms that are on the allow-list.
        self._algorithms = {
            name: alg
            for name, alg in get_default_algorithms().items()
            if name in self._valid_algs
        }

        self.options: SigOptions = self._get_default_options()
        if options is not None:
            self.options = {**self.options, **options}

    @staticmethod
    def _get_default_options() -> SigOptions:
        """Return a new dict holding the default signature options."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        The name is also added to this instance's allow-list.

        :param str alg_id: the ID of the Algorithm
        :param alg_obj: the Algorithm object
        :type alg_obj: Algorithm
        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        The name is removed from both the handler registry and the allow-list.

        :param str alg_id: the ID of the Algorithm
        :raises KeyError: if algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        # Every registered handler name is also on the allow-list
        # (see __init__ and register_algorithm), so this cannot fail.
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the `alg` parameter.

        This is the instance's allow-list. It can include names passed to the
        constructor that have no registered handler. The order is unspecified
        because the list is built from a set.

        :rtype: list[str]
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:
        >>> jws_obj = PyJWS()
        >>> jws_obj.get_algorithm_by_name("RS256")

        :param alg_name: The name of the algorithm to retrieve
        :type alg_name: str
        :rtype: Algorithm
        :raises NotImplementedError: if no handler is registered under
            ``alg_name``. The message mentions ``cryptography`` when the
            algorithm needs that package and it is not installed.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return a compact-serialized JWS string.

        :param payload: raw bytes to sign.
        :param key: signing key. For a ``PyJWK``, its ``key`` attribute is
            used.
        :param algorithm: algorithm name. When ``None``, a ``PyJWK``'s
            ``algorithm_name`` is used, or ``"none"`` for any other key type.
            A truthy ``"alg"`` in ``headers`` takes precedence.
        :param headers: extra header fields merged over the defaults
            (``typ`` and ``alg``). A falsy ``typ`` removes that field, and
            ``"b64": False`` forces a detached payload.
        :param json_encoder: ``JSONEncoder`` subclass used to serialize the
            header.
        :param is_payload_detached: when true, the raw (not base64url-encoded)
            payload is signed, the header gets ``"b64": false``, and the
            token's payload segment is left empty.
        :param sort_headers: sort header keys when serializing.
        :returns: the token as ``header.payload.signature``.
        :raises NotImplementedError: if the chosen algorithm has no handler.
        :raises InvalidTokenError: if ``headers["kid"]`` is present but is not
            a string.
        """
        segments: list[bytes] = []

        # declare a new var to narrow the type for type checkers
        if algorithm is None:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "none"
        else:
            algorithm_ = algorithm

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers_alg

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)

        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # With b64=false the signing input uses the payload bytes unencoded.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Parse ``jwt``, optionally verify it, and return all of its parts.

        :param jwt: compact-serialized token.
        :param key: verification key or ``PyJWK``.
        :param algorithms: allow-list for the header's ``alg``. It is required
            when verifying unless ``key`` is a ``PyJWK``.
        :param options: per-call overrides merged over ``self.options``.
        :param detached_payload: payload bytes. Required when the header has
            ``"b64": false``, and then used in place of the token's payload
            segment.
        :param kwargs: deprecated. Ignored apart from emitting a
            ``RemovedInPyjwt3Warning``.
        :returns: ``{"payload": bytes, "header": dict, "signature": bytes}``.
        :raises DecodeError: on malformed tokens or missing required arguments.
        :raises InvalidAlgorithmError: if ``alg`` is missing, disallowed or
            unsupported (only when verifying).
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        merged_options: SigOptions
        if options is None:
            merged_options = self.options
        else:
            merged_options = {**self.options, **options}

        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input as "<header segment>.<raw payload>".
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> Any:
        """Like :meth:`decode_complete`, but return only the payload.

        Extra ``kwargs`` are deprecated. They are not forwarded, so the
        deprecation warning is emitted once, from this method.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a `dict`

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split and decode a compact token without verifying it.

        The token is split on its last ``.`` into signing input and signature,
        and the signing input is split on its first ``.`` into header and
        payload.

        :returns: ``(payload, signing_input, header, signature)``. The payload
            and signature are base64url-decoded, and ``signing_input`` is the
            raw ``header.payload`` bytes.
        :raises DecodeError: for a non-str/bytes token, missing segments, bad
            base64url, or a header that is not a JSON object.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header: dict[str, Any] = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
    ) -> None:
        """Check the header's ``alg`` against the allow-list and verify.

        For a ``PyJWK`` key, ``algorithms`` defaults to the key's own
        ``algorithm_name``, and verification always uses the key's
        ``Algorithm`` object rather than a lookup by the header's ``alg``.

        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not
            allowed, or has no registered handler.
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate header fields that have type constraints (currently ``kid``)."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

352 

353 

# One shared PyJWS instance backs the module-level convenience API. Because
# every alias below is a bound method of this same object, an algorithm added
# via ``register_algorithm`` is also visible to ``encode``/``decode`` here.
_jws_global_obj = PyJWS()

# Token creation, verification and inspection.
encode = _jws_global_obj.encode
decode = _jws_global_obj.decode
decode_complete = _jws_global_obj.decode_complete
get_unverified_header = _jws_global_obj.get_unverified_header

# Algorithm registry management on the shared instance.
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name