Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jwt.py: 63%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

165 statements  

1from __future__ import annotations 

2 

3import json 

4import warnings 

5from calendar import timegm 

6from collections.abc import Iterable, Sequence 

7from datetime import datetime, timedelta, timezone 

8from typing import TYPE_CHECKING, Any 

9 

10from . import api_jws 

11from .exceptions import ( 

12 DecodeError, 

13 ExpiredSignatureError, 

14 ImmatureSignatureError, 

15 InvalidAudienceError, 

16 InvalidIssuedAtError, 

17 InvalidIssuerError, 

18 InvalidJTIError, 

19 InvalidSubjectError, 

20 MissingRequiredClaimError, 

21) 

22from .warnings import RemovedInPyjwt3Warning 

23 

24if TYPE_CHECKING: 

25 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys 

26 from .api_jwk import PyJWK 

27 

28 

29class PyJWT: 

30 def __init__(self, options: dict[str, Any] | None = None) -> None: 

31 if options is None: 

32 options = {} 

33 self.options: dict[str, Any] = {**self._get_default_options(), **options} 

34 

35 @staticmethod 

36 def _get_default_options() -> dict[str, bool | list[str]]: 

37 return { 

38 "verify_signature": True, 

39 "verify_exp": True, 

40 "verify_nbf": True, 

41 "verify_iat": True, 

42 "verify_aud": True, 

43 "verify_iss": True, 

44 "verify_sub": True, 

45 "verify_jti": True, 

46 "require": [], 

47 } 

48 

49 def encode( 

50 self, 

51 payload: dict[str, Any], 

52 key: AllowedPrivateKeys | PyJWK | str | bytes, 

53 algorithm: str | None = None, 

54 headers: dict[str, Any] | None = None, 

55 json_encoder: type[json.JSONEncoder] | None = None, 

56 sort_headers: bool = True, 

57 ) -> str: 

58 # Check that we get a dict 

59 if not isinstance(payload, dict): 

60 raise TypeError( 

61 "Expecting a dict object, as JWT only supports " 

62 "JSON objects as payloads." 

63 ) 

64 

65 # Payload 

66 payload = payload.copy() 

67 for time_claim in ["exp", "iat", "nbf"]: 

68 # Convert datetime to a intDate value in known time-format claims 

69 if isinstance(payload.get(time_claim), datetime): 

70 payload[time_claim] = timegm(payload[time_claim].utctimetuple()) 

71 

72 json_payload = self._encode_payload( 

73 payload, 

74 headers=headers, 

75 json_encoder=json_encoder, 

76 ) 

77 

78 return api_jws.encode( 

79 json_payload, 

80 key, 

81 algorithm, 

82 headers, 

83 json_encoder, 

84 sort_headers=sort_headers, 

85 ) 

86 

87 def _encode_payload( 

88 self, 

89 payload: dict[str, Any], 

90 headers: dict[str, Any] | None = None, 

91 json_encoder: type[json.JSONEncoder] | None = None, 

92 ) -> bytes: 

93 """ 

94 Encode a given payload to the bytes to be signed. 

95 

96 This method is intended to be overridden by subclasses that need to 

97 encode the payload in a different way, e.g. compress the payload. 

98 """ 

99 return json.dumps( 

100 payload, 

101 separators=(",", ":"), 

102 cls=json_encoder, 

103 ).encode("utf-8") 

104 

105 def decode_complete( 

106 self, 

107 jwt: str | bytes, 

108 key: AllowedPublicKeys | PyJWK | str | bytes = "", 

109 algorithms: Sequence[str] | None = None, 

110 options: dict[str, Any] | None = None, 

111 # deprecated arg, remove in pyjwt3 

112 verify: bool | None = None, 

113 # could be used as passthrough to api_jws, consider removal in pyjwt3 

114 detached_payload: bytes | None = None, 

115 # passthrough arguments to _validate_claims 

116 # consider putting in options 

117 audience: str | Iterable[str] | None = None, 

118 issuer: str | Sequence[str] | None = None, 

119 subject: str | None = None, 

120 leeway: float | timedelta = 0, 

121 # kwargs 

122 **kwargs: Any, 

123 ) -> dict[str, Any]: 

124 if kwargs: 

125 warnings.warn( 

126 "passing additional kwargs to decode_complete() is deprecated " 

127 "and will be removed in pyjwt version 3. " 

128 f"Unsupported kwargs: {tuple(kwargs.keys())}", 

129 RemovedInPyjwt3Warning, 

130 stacklevel=2, 

131 ) 

132 options = dict(options or {}) # shallow-copy or initialize an empty dict 

133 options.setdefault("verify_signature", True) 

134 

135 # If the user has set the legacy `verify` argument, and it doesn't match 

136 # what the relevant `options` entry for the argument is, inform the user 

137 # that they're likely making a mistake. 

138 if verify is not None and verify != options["verify_signature"]: 

139 warnings.warn( 

140 "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. " 

141 "The equivalent is setting `verify_signature` to False in the `options` dictionary. " 

142 "This invocation has a mismatch between the kwarg and the option entry.", 

143 category=DeprecationWarning, 

144 stacklevel=2, 

145 ) 

146 

147 if not options["verify_signature"]: 

148 options.setdefault("verify_exp", False) 

149 options.setdefault("verify_nbf", False) 

150 options.setdefault("verify_iat", False) 

151 options.setdefault("verify_aud", False) 

152 options.setdefault("verify_iss", False) 

153 options.setdefault("verify_sub", False) 

154 options.setdefault("verify_jti", False) 

155 

156 decoded = api_jws.decode_complete( 

157 jwt, 

158 key=key, 

159 algorithms=algorithms, 

160 options=options, 

161 detached_payload=detached_payload, 

162 ) 

163 

164 payload = self._decode_payload(decoded) 

165 

166 merged_options = {**self.options, **options} 

167 self._validate_claims( 

168 payload, 

169 merged_options, 

170 audience=audience, 

171 issuer=issuer, 

172 leeway=leeway, 

173 subject=subject, 

174 ) 

175 

176 decoded["payload"] = payload 

177 return decoded 

178 

179 def _decode_payload(self, decoded: dict[str, Any]) -> Any: 

180 """ 

181 Decode the payload from a JWS dictionary (payload, signature, header). 

182 

183 This method is intended to be overridden by subclasses that need to 

184 decode the payload in a different way, e.g. decompress compressed 

185 payloads. 

186 """ 

187 try: 

188 payload = json.loads(decoded["payload"]) 

189 except ValueError as e: 

190 raise DecodeError(f"Invalid payload string: {e}") from e 

191 if not isinstance(payload, dict): 

192 raise DecodeError("Invalid payload string: must be a json object") 

193 return payload 

194 

195 def decode( 

196 self, 

197 jwt: str | bytes, 

198 key: AllowedPublicKeys | PyJWK | str | bytes = "", 

199 algorithms: Sequence[str] | None = None, 

200 options: dict[str, Any] | None = None, 

201 # deprecated arg, remove in pyjwt3 

202 verify: bool | None = None, 

203 # could be used as passthrough to api_jws, consider removal in pyjwt3 

204 detached_payload: bytes | None = None, 

205 # passthrough arguments to _validate_claims 

206 # consider putting in options 

207 audience: str | Iterable[str] | None = None, 

208 subject: str | None = None, 

209 issuer: str | Sequence[str] | None = None, 

210 leeway: float | timedelta = 0, 

211 # kwargs 

212 **kwargs: Any, 

213 ) -> Any: 

214 if kwargs: 

215 warnings.warn( 

216 "passing additional kwargs to decode() is deprecated " 

217 "and will be removed in pyjwt version 3. " 

218 f"Unsupported kwargs: {tuple(kwargs.keys())}", 

219 RemovedInPyjwt3Warning, 

220 stacklevel=2, 

221 ) 

222 decoded = self.decode_complete( 

223 jwt, 

224 key, 

225 algorithms, 

226 options, 

227 verify=verify, 

228 detached_payload=detached_payload, 

229 audience=audience, 

230 subject=subject, 

231 issuer=issuer, 

232 leeway=leeway, 

233 ) 

234 return decoded["payload"] 

235 

236 def _validate_claims( 

237 self, 

238 payload: dict[str, Any], 

239 options: dict[str, Any], 

240 audience=None, 

241 issuer=None, 

242 subject: str | None = None, 

243 leeway: float | timedelta = 0, 

244 ) -> None: 

245 if isinstance(leeway, timedelta): 

246 leeway = leeway.total_seconds() 

247 

248 if audience is not None and not isinstance(audience, (str, Iterable)): 

249 raise TypeError("audience must be a string, iterable or None") 

250 

251 self._validate_required_claims(payload, options) 

252 

253 now = datetime.now(tz=timezone.utc).timestamp() 

254 

255 if "iat" in payload and options["verify_iat"]: 

256 self._validate_iat(payload, now, leeway) 

257 

258 if "nbf" in payload and options["verify_nbf"]: 

259 self._validate_nbf(payload, now, leeway) 

260 

261 if "exp" in payload and options["verify_exp"]: 

262 self._validate_exp(payload, now, leeway) 

263 

264 if options["verify_iss"]: 

265 self._validate_iss(payload, issuer) 

266 

267 if options["verify_aud"]: 

268 self._validate_aud( 

269 payload, audience, strict=options.get("strict_aud", False) 

270 ) 

271 

272 if options["verify_sub"]: 

273 self._validate_sub(payload, subject) 

274 

275 if options["verify_jti"]: 

276 self._validate_jti(payload) 

277 

278 def _validate_required_claims( 

279 self, 

280 payload: dict[str, Any], 

281 options: dict[str, Any], 

282 ) -> None: 

283 for claim in options["require"]: 

284 if payload.get(claim) is None: 

285 raise MissingRequiredClaimError(claim) 

286 

287 def _validate_sub(self, payload: dict[str, Any], subject=None) -> None: 

288 """ 

289 Checks whether "sub" if in the payload is valid ot not. 

290 This is an Optional claim 

291 

292 :param payload(dict): The payload which needs to be validated 

293 :param subject(str): The subject of the token 

294 """ 

295 

296 if "sub" not in payload: 

297 return 

298 

299 if not isinstance(payload["sub"], str): 

300 raise InvalidSubjectError("Subject must be a string") 

301 

302 if subject is not None: 

303 if payload.get("sub") != subject: 

304 raise InvalidSubjectError("Invalid subject") 

305 

306 def _validate_jti(self, payload: dict[str, Any]) -> None: 

307 """ 

308 Checks whether "jti" if in the payload is valid ot not 

309 This is an Optional claim 

310 

311 :param payload(dict): The payload which needs to be validated 

312 """ 

313 

314 if "jti" not in payload: 

315 return 

316 

317 if not isinstance(payload.get("jti"), str): 

318 raise InvalidJTIError("JWT ID must be a string") 

319 

320 def _validate_iat( 

321 self, 

322 payload: dict[str, Any], 

323 now: float, 

324 leeway: float, 

325 ) -> None: 

326 try: 

327 iat = int(payload["iat"]) 

328 except ValueError: 

329 raise InvalidIssuedAtError( 

330 "Issued At claim (iat) must be an integer." 

331 ) from None 

332 if iat > (now + leeway): 

333 raise ImmatureSignatureError("The token is not yet valid (iat)") 

334 

335 def _validate_nbf( 

336 self, 

337 payload: dict[str, Any], 

338 now: float, 

339 leeway: float, 

340 ) -> None: 

341 try: 

342 nbf = int(payload["nbf"]) 

343 except ValueError: 

344 raise DecodeError("Not Before claim (nbf) must be an integer.") from None 

345 

346 if nbf > (now + leeway): 

347 raise ImmatureSignatureError("The token is not yet valid (nbf)") 

348 

349 def _validate_exp( 

350 self, 

351 payload: dict[str, Any], 

352 now: float, 

353 leeway: float, 

354 ) -> None: 

355 try: 

356 exp = int(payload["exp"]) 

357 except ValueError: 

358 raise DecodeError( 

359 "Expiration Time claim (exp) must be an integer." 

360 ) from None 

361 

362 if exp <= (now - leeway): 

363 raise ExpiredSignatureError("Signature has expired") 

364 

365 def _validate_aud( 

366 self, 

367 payload: dict[str, Any], 

368 audience: str | Iterable[str] | None, 

369 *, 

370 strict: bool = False, 

371 ) -> None: 

372 if audience is None: 

373 if "aud" not in payload or not payload["aud"]: 

374 return 

375 # Application did not specify an audience, but 

376 # the token has the 'aud' claim 

377 raise InvalidAudienceError("Invalid audience") 

378 

379 if "aud" not in payload or not payload["aud"]: 

380 # Application specified an audience, but it could not be 

381 # verified since the token does not contain a claim. 

382 raise MissingRequiredClaimError("aud") 

383 

384 audience_claims = payload["aud"] 

385 

386 # In strict mode, we forbid list matching: the supplied audience 

387 # must be a string, and it must exactly match the audience claim. 

388 if strict: 

389 # Only a single audience is allowed in strict mode. 

390 if not isinstance(audience, str): 

391 raise InvalidAudienceError("Invalid audience (strict)") 

392 

393 # Only a single audience claim is allowed in strict mode. 

394 if not isinstance(audience_claims, str): 

395 raise InvalidAudienceError("Invalid claim format in token (strict)") 

396 

397 if audience != audience_claims: 

398 raise InvalidAudienceError("Audience doesn't match (strict)") 

399 

400 return 

401 

402 if isinstance(audience_claims, str): 

403 audience_claims = [audience_claims] 

404 if not isinstance(audience_claims, list): 

405 raise InvalidAudienceError("Invalid claim format in token") 

406 if any(not isinstance(c, str) for c in audience_claims): 

407 raise InvalidAudienceError("Invalid claim format in token") 

408 

409 if isinstance(audience, str): 

410 audience = [audience] 

411 

412 if all(aud not in audience_claims for aud in audience): 

413 raise InvalidAudienceError("Audience doesn't match") 

414 

415 def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None: 

416 if issuer is None: 

417 return 

418 

419 if "iss" not in payload: 

420 raise MissingRequiredClaimError("iss") 

421 

422 if isinstance(issuer, str): 

423 if payload["iss"] != issuer: 

424 raise InvalidIssuerError("Invalid issuer") 

425 else: 

426 if payload["iss"] not in issuer: 

427 raise InvalidIssuerError("Invalid issuer") 

428 

429 

430_jwt_global_obj = PyJWT() 

431encode = _jwt_global_obj.encode 

432decode_complete = _jwt_global_obj.decode_complete 

433decode = _jwt_global_obj.decode