Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 61%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

178 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from collections.abc import Sequence 

7from typing import TYPE_CHECKING, Any 

8 

9from .algorithms import ( 

10 Algorithm, 

11 get_default_algorithms, 

12 has_crypto, 

13 requires_cryptography, 

14) 

15from .api_jwk import PyJWK 

16from .exceptions import ( 

17 DecodeError, 

18 InvalidAlgorithmError, 

19 InvalidSignatureError, 

20 InvalidTokenError, 

21) 

22from .utils import base64url_decode, base64url_encode 

23from .warnings import RemovedInPyjwt3Warning 

24 

25if TYPE_CHECKING: 

26 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

27 

28 

class PyJWS:
    """Create and verify JSON Web Signatures in compact serialization.

    Each instance keeps its own registry of ``Algorithm`` handlers keyed by
    ``alg`` name, a whitelist of accepted ``alg`` names, and default options
    that :meth:`decode_complete` merges with per-call options.
    """

    # Value placed in the ``typ`` header by :meth:`encode` unless the
    # caller-supplied headers override it.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
    ) -> None:
        """Initialise the algorithm registry and the default options.

        :param algorithms: ``alg`` names to allow. When omitted, every
            algorithm returned by ``get_default_algorithms()`` is allowed.
            Handlers whose name is not listed are dropped from the registry.
        :param options: Option overrides layered on top of
            :meth:`_get_default_options`.
        """
        available = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(available)
        )

        # Keep only the handlers that are on the whitelist (insertion order kept).
        self._algorithms = {
            name: alg_obj
            for name, alg_obj in available.items()
            if name in self._valid_algs
        }

        self.options = {**self._get_default_options(), **(options or {})}

    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return the baseline options: signature verification enabled."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        # Also whitelist the name so it is reported by get_algorithms().
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        :raises KeyError: if ``alg_id`` is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a sorted list of the whitelisted values for the 'alg' parameter.

        The whitelist is the ``algorithms`` passed to the constructor (or the
        default algorithms) plus anything added via :meth:`register_algorithm`.
        It may therefore contain names that have no registered handler.
        """
        # Sort so the result is deterministic; set iteration order is not.
        return sorted(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        :raises NotImplementedError: if no handler is registered for
            ``alg_name``. The message hints at the missing ``cryptography``
            package when that is the likely cause.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = None,
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return the compact JWS string.

        :param payload: Raw payload bytes to sign.
        :param key: Signing key. For a ``PyJWK``, its ``key`` attribute is
            used and, when ``algorithm`` is None, its ``algorithm_name`` too.
        :param algorithm: ``alg`` to sign with. Defaults to ``"HS256"`` for
            non-JWK keys.
        :param headers: Extra header fields. A truthy ``alg`` entry overrides
            ``algorithm``, ``b64: False`` forces a detached payload, and a
            falsy ``typ`` removes the ``typ`` header.
        :param json_encoder: Encoder class used to serialise the header.
        :param is_payload_detached: Sign the raw (not base64url-encoded)
            payload, set ``b64: false``, and leave the payload segment of the
            returned token empty.
        :param sort_headers: Emit header keys in sorted order.
        :raises NotImplementedError: if the algorithm is not registered.
        :raises InvalidTokenError: if a ``kid`` header is not a string.
        """
        # Resolve the effective algorithm into a non-optional local so type
        # checkers can narrow it.
        if algorithm is not None:
            algorithm_ = algorithm
        elif isinstance(key, PyJWK):
            algorithm_ = key.algorithm_name
        else:
            algorithm_ = "HS256"

        # Header values take precedence over function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers_alg
            if headers.get("b64") is False:
                is_payload_detached = True

        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
        if headers:
            self._validate_headers(headers)
            header.update(headers)

        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        # A detached (b64=false) payload is signed as-is rather than encoded.
        msg_payload = payload if is_payload_detached else base64url_encode(payload)
        segments = [base64url_encode(json_header), msg_payload]
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""

        return b".".join(segments).decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Decode ``jwt`` and, unless disabled, verify its signature.

        :param jwt: Compact JWS token.
        :param key: Verification key, or a ``PyJWK``.
        :param algorithms: Accepted ``alg`` values. Required when verifying
            with a non-JWK key. For a ``PyJWK`` it defaults to the key's own
            algorithm.
        :param options: Per-call overrides of ``self.options``
            (e.g. ``{"verify_signature": False}``).
        :param detached_payload: Payload to use when the header has
            ``b64: false``.
        :param kwargs: Deprecated and ignored. A warning is emitted.
        :returns: ``{"payload": bytes, "header": dict, "signature": bytes}``.
        :raises DecodeError: if ``algorithms`` is missing when required, the
            token is malformed, or a required detached payload is not given.
        :raises InvalidAlgorithmError: if the header ``alg`` is missing, not
            allowed, or not supported.
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        merged_options = {**self.options, **(options or {})}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # The signed input is "<header segment>.<raw detached payload>".
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> Any:
        """Like :meth:`decode_complete` but return only the payload."""
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split a compact token into its parts and decode them.

        :returns: ``(payload, signing_input, header, signature)``, where
            ``signing_input`` is the raw ``header.payload`` bytes that were
            signed.
        :raises DecodeError: if the token is not ``str``/``bytes``, does not
            have exactly three segments, or a segment cannot be decoded.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        # Compact serialization has exactly three segments (RFC 7515, 7.1).
        # Without this check, any extra segments would be folded into the
        # payload segment, where "." is not a valid base64url character.
        if b"." in payload_segment:
            raise DecodeError("Too many segments")

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
    ) -> None:
        """Check ``signature`` over ``signing_input`` using the header's ``alg``.

        For a ``PyJWK`` key, the key's own algorithm object and key material
        are used. The header ``alg`` must still be in ``algorithms``, which
        defaults to ``[key.algorithm_name]``.

        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not in
            ``algorithms``, or has no registered handler.
        :raises InvalidSignatureError: if verification fails.
        """
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate header fields that have type constraints (currently ``kid``)."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a string."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

338 

# Module-level convenience API. Each name below is a bound method of one
# shared, default-configured PyJWS instance, so state changes made through
# one alias (e.g. an algorithm added via ``register_algorithm``) are seen by
# the others (e.g. ``encode``).
_jws_global_obj = PyJWS()
(
    encode,
    decode_complete,
    decode,
    register_algorithm,
    unregister_algorithm,
    get_algorithm_by_name,
    get_unverified_header,
) = (
    _jws_global_obj.encode,
    _jws_global_obj.decode_complete,
    _jws_global_obj.decode,
    _jws_global_obj.register_algorithm,
    _jws_global_obj.unregister_algorithm,
    _jws_global_obj.get_algorithm_by_name,
    _jws_global_obj.get_unverified_header,
)