Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/api_jws.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

171 statements  

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from typing import TYPE_CHECKING, Any 

7 

8from .algorithms import ( 

9 Algorithm, 

10 get_default_algorithms, 

11 has_crypto, 

12 requires_cryptography, 

13) 

14from .api_jwk import PyJWK 

15from .exceptions import ( 

16 DecodeError, 

17 InvalidAlgorithmError, 

18 InvalidSignatureError, 

19 InvalidTokenError, 

20) 

21from .utils import base64url_decode, base64url_encode 

22from .warnings import RemovedInPyjwt3Warning 

23 

24if TYPE_CHECKING: 

25 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

26 

27 

class PyJWS:
    """Create and verify JSON Web Signatures (JWS) in compact serialization.

    An instance owns a registry of algorithm handlers, restricted to an
    optional whitelist, and a dict of default options that ``decode`` /
    ``decode_complete`` merge per call.
    """

    header_typ = "JWT"

    def __init__(
        self,
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
    ) -> None:
        """
        :param algorithms: Whitelist of algorithm names. ``None`` allows every
            algorithm returned by ``get_default_algorithms()``.
        :param options: Overrides merged on top of ``_get_default_options()``.
        """
        default_algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(default_algorithms)
        )

        # Keep only the default handlers that are on the whitelist. Building a
        # new dict avoids mutating whatever get_default_algorithms() returned.
        self._algorithms = {
            name: alg_obj
            for name, alg_obj in default_algorithms.items()
            if name in self._valid_algs
        }

        if options is None:
            options = {}
        self.options = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return the option defaults; signature verification is on."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        The name is also added to the whitelist.

        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        The name is removed from both the handler registry and the whitelist.

        :raises KeyError: if the algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the 'alg' parameter.

        The list is built from the whitelist set, so its order is unspecified.
        It can include names passed to ``__init__`` that have no registered
        handler yet.
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        :raises NotImplementedError: if no handler is registered under
            ``alg_name``. The message hints at installing ``cryptography``
            when the algorithm needs it and it is unavailable.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | str | bytes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return the compact ``header.payload.signature`` string.

        :param payload: Raw payload bytes to sign.
        :param key: Signing key, passed through ``Algorithm.prepare_key``.
        :param algorithm: Algorithm name. ``None`` means ``"none"``. A truthy
            ``"alg"`` in ``headers`` takes precedence.
        :param headers: Extra header fields merged over ``typ``/``alg``.
            ``"b64": False`` forces a detached payload.
        :param json_encoder: Optional ``JSONEncoder`` subclass for the header.
        :param is_payload_detached: Sign the raw payload and leave the payload
            segment empty in the output.
        :param sort_headers: Serialize header keys in sorted order.
        :raises InvalidTokenError: if ``headers["kid"]`` is not a string.
        :raises NotImplementedError: if the algorithm is not registered.
        """
        segments = []

        # declare a new var to narrow the type for type checkers
        algorithm_: str = algorithm if algorithm is not None else "none"

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers["alg"]

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)
            # A falsy "alg" in headers (None, "") was ignored above when
            # choosing the signing algorithm. Don't let update() copy it into
            # the header: the advertised alg must match the one used to sign.
            # Re-assigning an existing key keeps its position in the dict.
            header["alg"] = algorithm_

        # A falsy "typ" supplied through headers means "omit the typ field".
        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # With b64=false (RFC 7797) the raw payload bytes are signed.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Parse ``jwt`` and, unless disabled by options, verify its signature.

        :param jwt: Compact JWS as ``str`` or ``bytes``.
        :param key: Verification key, or a ``PyJWK`` that supplies both key
            and algorithm.
        :param algorithms: Allowed ``alg`` values. Required when verifying,
            unless ``key`` is a ``PyJWK``.
        :param options: Per-call overrides merged over ``self.options``.
        :param detached_payload: Payload to use when the header has
            ``"b64": false``.
        :returns: ``{"payload": bytes, "header": dict, "signature": bytes}``.
        :raises DecodeError: if ``algorithms`` is missing while verifying, the
            token is malformed, or a required detached payload is absent.
        :raises InvalidAlgorithmError: if the header ``alg`` is missing or not
            allowed.
        :raises InvalidSignatureError: if signature verification fails.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
            )
        if options is None:
            options = {}
        merged_options = {**self.options, **options}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input as <header segment>.<raw payload>.
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> Any:
        """Same as ``decode_complete`` but return only the payload."""
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split and base64url-decode a compact JWS.

        :returns: ``(payload, signing_input, header, signature)``, where
            ``signing_input`` is the raw ``header.payload`` segment bytes.
        :raises DecodeError: on a wrong input type, missing segments, bad
            base64, or a header that is not a JSON object.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
    ) -> None:
        """Check ``signature`` over ``signing_input``.

        If ``key`` is a ``PyJWK`` and no ``algorithms`` are given, only the
        key's own algorithm is allowed. A ``PyJWK`` key always supplies the
        algorithm object used for verification.

        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not in
            ``algorithms``, or has no registered handler.
        :raises InvalidSignatureError: if verification fails.
        """
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError as err:
            raise InvalidAlgorithmError("Algorithm not specified") from err

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate known header fields; currently only ``kid``."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

327 

# One shared PyJWS with default algorithms and options. Its bound methods
# are re-exported below as the module-level functional API.
_jws_global_obj = PyJWS()

# Signing
encode = _jws_global_obj.encode

# Parsing / verification
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
get_unverified_header = _jws_global_obj.get_unverified_header

# Algorithm registry of the shared instance
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name