Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jwt/api_jwt.py: 64%

144 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-09 07:17 +0000

1from __future__ import annotations 

2 

3import json 

4import warnings 

5from calendar import timegm 

6from collections.abc import Iterable 

7from datetime import datetime, timedelta, timezone 

8from typing import TYPE_CHECKING, Any 

9 

10from . import api_jws 

11from .exceptions import ( 

12 DecodeError, 

13 ExpiredSignatureError, 

14 ImmatureSignatureError, 

15 InvalidAudienceError, 

16 InvalidIssuedAtError, 

17 InvalidIssuerError, 

18 MissingRequiredClaimError, 

19) 

20from .warnings import RemovedInPyjwt3Warning 

21 

22if TYPE_CHECKING: 

23 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

24 

25 

class PyJWT:
    """
    Encode and decode JSON Web Tokens and validate their claims.

    Signing and signature verification are delegated to ``api_jws``; this
    class serializes/deserializes the JSON payload and performs the
    ``exp``/``nbf``/``iat``/``aud``/``iss`` and required-claim checks.
    """

    def __init__(self, options: dict[str, Any] | None = None) -> None:
        """
        Store the instance-wide validation options.

        Entries in ``options`` override the values returned by
        ``_get_default_options``.
        """
        if options is None:
            options = {}
        self.options: dict[str, Any] = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool | list[str]]:
        """Return a fresh dict of default options (every check enabled)."""
        return {
            "verify_signature": True,
            "verify_exp": True,
            "verify_nbf": True,
            "verify_iat": True,
            "verify_aud": True,
            "verify_iss": True,
            "require": [],
        }

    def encode(
        self,
        payload: dict[str, Any],
        key: AllowedPrivateKeys | str | bytes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        sort_headers: bool = True,
    ) -> str:
        """
        Serialize ``payload`` and pass it to ``api_jws.encode`` for signing.

        ``datetime`` values found in the ``exp``, ``iat`` and ``nbf`` claims
        are converted to integer UTC timestamps first. The conversion works on
        a shallow copy, so the caller's dict is left untouched.

        Raises:
            TypeError: if ``payload`` is not a dict.
        """
        # Check that we get a dict
        if not isinstance(payload, dict):
            raise TypeError(
                "Expecting a dict object, as JWT only supports "
                "JSON objects as payloads."
            )

        # Payload
        payload = payload.copy()
        for time_claim in ["exp", "iat", "nbf"]:
            # Convert datetime to a intDate value in known time-format claims
            if isinstance(payload.get(time_claim), datetime):
                payload[time_claim] = timegm(payload[time_claim].utctimetuple())

        json_payload = self._encode_payload(
            payload,
            headers=headers,
            json_encoder=json_encoder,
        )

        return api_jws.encode(
            json_payload,
            key,
            algorithm,
            headers,
            json_encoder,
            sort_headers=sort_headers,
        )

    def _encode_payload(
        self,
        payload: dict[str, Any],
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
    ) -> bytes:
        """
        Encode a given payload to the bytes to be signed.

        This method is intended to be overridden by subclasses that need to
        encode the payload in a different way, e.g. compress the payload.
        ``headers`` is not used by this default implementation.
        """
        return json.dumps(
            payload,
            separators=(",", ":"),
            cls=json_encoder,
        ).encode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        issuer: str | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Decode ``jwt``, validate its claims and return the full decoded dict.

        The token is decoded by ``api_jws.decode_complete``; the ``"payload"``
        entry of its result is then replaced by the parsed claims dict.
        Per-call ``options`` take precedence over the instance options. When
        ``verify_signature`` is false, the claim checks default to disabled
        unless ``options`` sets them explicitly.

        Raises:
            DecodeError: if signature verification is requested without
                ``algorithms``, or the payload is not a JSON object.
            Other exceptions raised by the individual claim validators.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller rather than this module.
                stacklevel=2,
            )
        options = dict(options or {})  # shallow-copy or initialize an empty dict
        options.setdefault("verify_signature", True)

        # If the user has set the legacy `verify` argument, and it doesn't match
        # what the relevant `options` entry for the argument is, inform the user
        # that they're likely making a mistake.
        if verify is not None and verify != options["verify_signature"]:
            warnings.warn(
                "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
                "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
                "This invocation has a mismatch between the kwarg and the option entry.",
                category=DeprecationWarning,
                stacklevel=2,
            )

        if not options["verify_signature"]:
            # Without a verified signature the claims are untrusted anyway, so
            # skip their checks unless the caller asked for them explicitly.
            options.setdefault("verify_exp", False)
            options.setdefault("verify_nbf", False)
            options.setdefault("verify_iat", False)
            options.setdefault("verify_aud", False)
            options.setdefault("verify_iss", False)

        if options["verify_signature"] and not algorithms:
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        decoded = api_jws.decode_complete(
            jwt,
            key=key,
            algorithms=algorithms,
            options=options,
            detached_payload=detached_payload,
        )

        payload = self._decode_payload(decoded)

        merged_options = {**self.options, **options}
        self._validate_claims(
            payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
        )

        decoded["payload"] = payload
        return decoded

    def _decode_payload(self, decoded: dict[str, Any]) -> Any:
        """
        Decode the payload from a JWS dictionary (payload, signature, header).

        This method is intended to be overridden by subclasses that need to
        decode the payload in a different way, e.g. decompress compressed
        payloads.

        Raises:
            DecodeError: if the payload is not valid JSON or is not a JSON
                object.
        """
        try:
            payload = json.loads(decoded["payload"])
        except ValueError as e:
            raise DecodeError(f"Invalid payload string: {e}") from e
        if not isinstance(payload, dict):
            raise DecodeError("Invalid payload string: must be a json object")
        return payload

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        issuer: str | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> Any:
        """
        Decode ``jwt`` and return only its validated claims.

        Thin wrapper around ``decode_complete`` that returns the ``"payload"``
        entry of its result. Extra ``kwargs`` trigger a deprecation warning
        and are not forwarded.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller rather than this module.
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt,
            key,
            algorithms,
            options,
            verify=verify,
            detached_payload=detached_payload,
            audience=audience,
            issuer=issuer,
            leeway=leeway,
        )
        return decoded["payload"]

    def _validate_claims(
        self,
        payload: dict[str, Any],
        options: dict[str, Any],
        audience: str | Iterable[str] | None = None,
        issuer: str | None = None,
        leeway: float | timedelta = 0,
    ) -> None:
        """
        Run every claim check enabled in ``options`` against ``payload``.

        ``leeway`` (seconds, or a ``timedelta``) is the tolerance applied to
        the time-based checks. Time claims are only checked when present in
        the payload.

        Raises:
            TypeError: if ``audience`` is neither a string, an iterable nor
                ``None``.
            Whatever the individual ``_validate_*`` helpers raise.
        """
        if isinstance(leeway, timedelta):
            leeway = leeway.total_seconds()

        if audience is not None and not isinstance(audience, (str, Iterable)):
            raise TypeError("audience must be a string, iterable or None")

        self._validate_required_claims(payload, options)

        now = datetime.now(tz=timezone.utc).timestamp()

        if "iat" in payload and options["verify_iat"]:
            self._validate_iat(payload, now, leeway)

        if "nbf" in payload and options["verify_nbf"]:
            self._validate_nbf(payload, now, leeway)

        if "exp" in payload and options["verify_exp"]:
            self._validate_exp(payload, now, leeway)

        if options["verify_iss"]:
            self._validate_iss(payload, issuer)

        if options["verify_aud"]:
            self._validate_aud(
                payload, audience, strict=options.get("strict_aud", False)
            )

    def _validate_required_claims(
        self,
        payload: dict[str, Any],
        options: dict[str, Any],
    ) -> None:
        """
        Ensure every claim listed in ``options["require"]`` is present.

        A claim whose value is ``None`` counts as missing.

        Raises:
            MissingRequiredClaimError: for the first missing claim.
        """
        for claim in options["require"]:
            if payload.get(claim) is None:
                raise MissingRequiredClaimError(claim)

    def _validate_iat(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """
        Check that the ``iat`` claim is an integer and not in the future.

        Raises:
            InvalidIssuedAtError: if the claim cannot be converted to ``int``.
            ImmatureSignatureError: if ``iat`` is later than ``now + leeway``.
        """
        try:
            iat = int(payload["iat"])
        except (TypeError, ValueError, OverflowError) as e:
            # int() raises TypeError for null/list/object claims and
            # OverflowError for infinite floats; treat all as malformed.
            raise InvalidIssuedAtError(
                "Issued At claim (iat) must be an integer."
            ) from e
        if iat > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (iat)")

    def _validate_nbf(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """
        Check that the ``nbf`` claim is an integer and has been reached.

        Raises:
            DecodeError: if the claim cannot be converted to ``int``.
            ImmatureSignatureError: if ``nbf`` is later than ``now + leeway``.
        """
        try:
            nbf = int(payload["nbf"])
        except (TypeError, ValueError, OverflowError) as e:
            # See _validate_iat: non-numeric JSON values must not leak raw
            # TypeError/OverflowError to the caller.
            raise DecodeError("Not Before claim (nbf) must be an integer.") from e

        if nbf > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (nbf)")

    def _validate_exp(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """
        Check that the ``exp`` claim is an integer and still in the future.

        Raises:
            DecodeError: if the claim cannot be converted to ``int``.
            ExpiredSignatureError: if ``exp`` is at or before ``now - leeway``.
        """
        try:
            exp = int(payload["exp"])
        except (TypeError, ValueError, OverflowError) as e:
            # See _validate_iat: non-numeric JSON values must not leak raw
            # TypeError/OverflowError to the caller.
            raise DecodeError(
                "Expiration Time claim (exp) must be an integer."
            ) from e

        if exp <= (now - leeway):
            raise ExpiredSignatureError("Signature has expired")

    def _validate_aud(
        self,
        payload: dict[str, Any],
        audience: str | Iterable[str] | None,
        *,
        strict: bool = False,
    ) -> None:
        """
        Check the token's ``aud`` claim against the expected ``audience``.

        In the default mode the check passes when at least one expected
        audience appears in the claim (a string claim is treated as a
        one-element list). In ``strict`` mode both the expected audience and
        the claim must be single strings and must be equal.

        Raises:
            InvalidAudienceError: on a mismatch, a malformed claim, or when
                the token carries an audience but none was expected.
            MissingRequiredClaimError: if an audience is expected but the
                claim is absent or empty.
        """
        if audience is None:
            if "aud" not in payload or not payload["aud"]:
                return
            # Application did not specify an audience, but
            # the token has the 'aud' claim
            raise InvalidAudienceError("Invalid audience")

        if "aud" not in payload or not payload["aud"]:
            # Application specified an audience, but it could not be
            # verified since the token does not contain a claim.
            raise MissingRequiredClaimError("aud")

        audience_claims = payload["aud"]

        # In strict mode, we forbid list matching: the supplied audience
        # must be a string, and it must exactly match the audience claim.
        if strict:
            # Only a single audience is allowed in strict mode.
            if not isinstance(audience, str):
                raise InvalidAudienceError("Invalid audience (strict)")

            # Only a single audience claim is allowed in strict mode.
            if not isinstance(audience_claims, str):
                raise InvalidAudienceError("Invalid claim format in token (strict)")

            if audience != audience_claims:
                raise InvalidAudienceError("Audience doesn't match (strict)")

            return

        if isinstance(audience_claims, str):
            audience_claims = [audience_claims]
        if not isinstance(audience_claims, list):
            raise InvalidAudienceError("Invalid claim format in token")
        if any(not isinstance(c, str) for c in audience_claims):
            raise InvalidAudienceError("Invalid claim format in token")

        if isinstance(audience, str):
            audience = [audience]

        if all(aud not in audience_claims for aud in audience):
            raise InvalidAudienceError("Audience doesn't match")

    def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
        """
        Check the token's ``iss`` claim against the expected ``issuer``.

        Nothing is checked when ``issuer`` is ``None``.

        Raises:
            MissingRequiredClaimError: if the token has no ``iss`` claim.
            InvalidIssuerError: if the claim differs from ``issuer``.
        """
        if issuer is None:
            return

        if "iss" not in payload:
            raise MissingRequiredClaimError("iss")

        if payload["iss"] != issuer:
            raise InvalidIssuerError("Invalid issuer")

367 

368 

# Shared default-configured instance; its bound methods are re-exported so the
# module can be used through plain functions.
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode