Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/api_jws.py: 26%

165 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1from __future__ import annotations 

2 

3import binascii 

4import json 

5import warnings 

6from typing import TYPE_CHECKING, Any 

7 

8from .algorithms import ( 

9 Algorithm, 

10 get_default_algorithms, 

11 has_crypto, 

12 requires_cryptography, 

13) 

14from .exceptions import ( 

15 DecodeError, 

16 InvalidAlgorithmError, 

17 InvalidSignatureError, 

18 InvalidTokenError, 

19) 

20from .utils import base64url_decode, base64url_encode 

21from .warnings import RemovedInPyjwt3Warning 

22 

23if TYPE_CHECKING: 

24 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

25 

26 

class PyJWS:
    """Create and decode JSON Web Signatures in compact serialization.

    Each instance keeps its own algorithm registry (``_algorithms``), its own
    allow-list of algorithm names (``_valid_algs``) and its own default decode
    options (``options``).
    """

    # Default value for the ``typ`` header written by :meth:`encode`.
    # A falsy value (from headers or a subclass) drops ``typ`` from the header.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
    ) -> None:
        """Build the algorithm registry and the default decode options.

        Args:
            algorithms: Allow-list of algorithm names. ``None`` allows every
                algorithm returned by ``get_default_algorithms()``. Names with
                no default implementation stay on the allow-list but are not
                registered.
            options: Overrides merged over :meth:`_get_default_options`.
        """
        default_algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(default_algorithms)
        )

        # Remove algorithms that aren't on the whitelist. A filtered copy is
        # built instead of deleting keys, so the dict returned by
        # get_default_algorithms() is never mutated.
        self._algorithms = {
            alg_id: alg_obj
            for alg_id, alg_obj in default_algorithms.items()
            if alg_id in self._valid_algs
        }

        if options is None:
            options = {}
        self.options = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return the default decode options (signature verification on)."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        The name is also added to this instance's allow-list.

        Raises:
            ValueError: if ``alg_id`` already has a handler.
            TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        The name is removed from both the registry and the allow-list.

        Raises:
            KeyError: if the algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the 'alg' parameter.

        This is the allow-list. It can contain names that have no registered
        implementation if such names were passed to the constructor.
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        Raises:
            NotImplementedError: if no algorithm is registered under
                ``alg_name``. The message mentions ``cryptography`` when the
                name requires it and it is not installed.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | str | bytes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return a compact ``header.payload.signature`` token.

        Args:
            payload: Raw payload bytes. They are base64url-encoded into the
                token unless the payload is detached.
            key: Signing key, passed through the algorithm's ``prepare_key``.
            algorithm: Signing algorithm name. ``None`` selects ``"none"``.
                A truthy ``"alg"`` entry in ``headers`` takes precedence.
            headers: Extra header parameters merged over the default
                ``typ``/``alg`` entries. ``"b64": False`` forces a detached
                payload.
            json_encoder: Optional ``JSONEncoder`` subclass used for the header.
            is_payload_detached: Sign the raw (unencoded) payload, set
                ``"b64": false`` in the header, and leave the payload segment
                of the returned token empty.
            sort_headers: Sort header keys when serializing the header.

        Returns:
            The token as a UTF-8 ``str``.

        Raises:
            NotImplementedError: if the chosen algorithm is not registered.
            InvalidTokenError: if ``headers`` has a non-string ``kid``.
        """
        segments: list[bytes] = []

        # declare a new var to narrow the type for type checkers
        algorithm_: str = algorithm if algorithm is not None else "none"

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers_alg

            if headers.get("b64") is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)
            # A falsy "alg" in headers (None, "") was ignored above when the
            # signing algorithm was chosen. Re-assert the real algorithm so the
            # header never advertises an alg the token was not signed with.
            header["alg"] = algorithm_

        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # With a detached (b64=false) payload, the raw bytes are signed as-is.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Decode a compact JWS and, unless disabled, verify its signature.

        Args:
            jwt: The token, as ``str`` (UTF-8 encoded first) or ``bytes``.
            key: Verification key, passed through the algorithm's ``prepare_key``.
            algorithms: Algorithm names accepted for the header's ``alg``.
                Required when signature verification is enabled.
            options: Overrides merged over ``self.options``
                (e.g. ``{"verify_signature": False}``).
            detached_payload: Payload bytes for tokens whose header has
                ``"b64": false``.
            **kwargs: Deprecated; ignored apart from a warning.

        Returns:
            ``{"payload": bytes, "header": dict, "signature": bytes}``. When
            verification is disabled, none of these are authenticated.

        Raises:
            DecodeError: if ``algorithms`` is missing while verifying, the
                token is malformed, or a required detached payload is missing.
            InvalidAlgorithmError: if the header's ``alg`` is missing, not
                allowed, or not supported.
            InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        if options is None:
            options = {}
        merged_options = {**self.options, **options}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms:
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # The signature was computed over the raw detached payload (see
            # encode), so swap it in after the last "." of the signing input.
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> Any:
        """Like :meth:`decode_complete`, but return only the payload."""
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split a compact token and base64url-decode its parts.

        Returns:
            ``(payload, signing_input, header, signature)``. ``signing_input``
            is the raw bytes before the last ``"."``.

        Raises:
            DecodeError: on a non-str/bytes token, missing segments, bad
                base64, or a header that is not a JSON object.

        NOTE(review): the token is split on the last ``"."`` and then on the
        first ``"."``, so tokens with more than two dots are not rejected here.
        Any extra dots end up in the payload segment.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | str | bytes = "",
        algorithms: list[str] | None = None,
    ) -> None:
        """Check the header's ``alg`` against ``algorithms``, then verify.

        Raises:
            InvalidAlgorithmError: if ``alg`` is missing, falsy, not in
                ``algorithms``, or not registered on this instance.
            InvalidSignatureError: if the algorithm rejects the signature.
        """
        try:
            alg = header["alg"]
        except KeyError as err:
            raise InvalidAlgorithmError("Algorithm not specified") from err

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        try:
            alg_obj = self.get_algorithm_by_name(alg)
        except NotImplementedError as e:
            raise InvalidAlgorithmError("Algorithm not supported") from e
        prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate header parameters; currently only ``kid`` is checked."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

# One shared, default-configured PyJWS instance backs the module-level API;
# every name below is a bound method of that instance.
_jws_global_obj = PyJWS()

# Token creation and decoding.
encode = _jws_global_obj.encode
decode = _jws_global_obj.decode
decode_complete = _jws_global_obj.decode_complete

# Algorithm registry of the shared instance.
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name

# Header inspection without signature verification.
get_unverified_header = _jws_global_obj.get_unverified_header