1import uuid
2from datetime import datetime
3from datetime import timedelta
4from datetime import timezone
5from hmac import compare_digest
6from json import JSONEncoder
7from typing import Any
8from typing import Iterable
9from typing import List
10from typing import Type
11from typing import Union
12
13import jwt
14
15from flask_jwt_extended.exceptions import CSRFError
16from flask_jwt_extended.exceptions import JWTDecodeError
17from flask_jwt_extended.typing import ExpiresDelta
18from flask_jwt_extended.typing import Fresh
19
20
def _encode_jwt(
    algorithm: str,
    audience: Union[str, Iterable[str]],
    claim_overrides: dict,
    csrf: bool,
    expires_delta: ExpiresDelta,
    fresh: Fresh,
    header_overrides: dict,
    identity: Any,
    identity_claim_key: str,
    issuer: str,
    json_encoder: Type[JSONEncoder],
    secret: str,
    token_type: str,
    nbf: bool,
) -> str:
    """Assemble the claim set for a new token and sign it with ``jwt.encode``.

    Args:
        algorithm: Signing algorithm name forwarded to ``jwt.encode``.
        audience: Value for the ``aud`` claim; the claim is omitted when falsy.
        claim_overrides: Extra claims merged in last, so they replace any
            claim generated here (``exp``, ``jti``, ``type``, ...).
        csrf: When true, add a random ``csrf`` claim (a UUID4 string).
        expires_delta: Lifetime added to the issue time to form ``exp``.
            ``False`` (or ``None``) means no ``exp`` claim is written.
        fresh: Stored as-is in the ``fresh`` claim, unless it is a
            ``timedelta``, in which case it is converted to the POSIX
            timestamp (issue time + delta) at which freshness ends.
        header_overrides: Extra header fields forwarded to ``jwt.encode``.
        identity: Value stored under ``identity_claim_key``.
        identity_claim_key: Claim name that holds ``identity``.
        issuer: Value for the ``iss`` claim; the claim is omitted when falsy.
        json_encoder: Encoder class forwarded to ``jwt.encode`` for payload
            serialization.
        secret: Signing key.
        token_type: Value of the ``type`` claim.
        nbf: When true, add an ``nbf`` claim equal to the issue time.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    # One timestamp for iat/nbf/exp/fresh so they are mutually consistent.
    now = datetime.now(timezone.utc)

    if isinstance(fresh, timedelta):
        fresh = datetime.timestamp(now + fresh)

    token_data = {
        "fresh": fresh,
        "iat": now,
        "jti": str(uuid.uuid4()),
        "type": token_type,
        identity_claim_key: identity,
    }

    if nbf:
        token_data["nbf"] = now

    if csrf:
        token_data["csrf"] = str(uuid.uuid4())

    if audience:
        token_data["aud"] = audience

    if issuer:
        token_data["iss"] = issuer

    # Test for the "disabled" sentinels explicitly rather than truthiness:
    # a zero-length delta such as timedelta(0) is falsy, and a plain
    # `if expires_delta:` would silently issue a token that never expires.
    if expires_delta is not False and expires_delta is not None:
        token_data["exp"] = now + expires_delta

    # Applied last so caller-supplied claims win over generated ones.
    if claim_overrides:
        token_data.update(claim_overrides)

    return jwt.encode(
        token_data,
        secret,
        algorithm,
        json_encoder=json_encoder,  # type: ignore
        headers=header_overrides,
    )
75
76
77def _decode_jwt(
78 algorithms: List,
79 allow_expired: bool,
80 audience: Union[str, Iterable[str]],
81 csrf_value: str,
82 encoded_token: str,
83 identity_claim_key: str,
84 issuer: str,
85 leeway: int,
86 secret: str,
87 verify_aud: bool,
88 verify_sub: bool,
89) -> dict:
90 options = {"verify_aud": verify_aud, "verify_sub": verify_sub}
91 if allow_expired:
92 options["verify_exp"] = False
93
94 # This call verifies the ext, iat, and nbf claims
95 # This optionally verifies the exp and aud claims if enabled
96 decoded_token = jwt.decode(
97 encoded_token,
98 secret,
99 algorithms=algorithms,
100 audience=audience,
101 issuer=issuer,
102 leeway=leeway,
103 options=options,
104 )
105
106 # Make sure that any custom claims we expect in the token are present
107 if identity_claim_key not in decoded_token:
108 raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
109
110 if "type" not in decoded_token:
111 decoded_token["type"] = "access"
112
113 if "fresh" not in decoded_token:
114 decoded_token["fresh"] = False
115
116 if "jti" not in decoded_token:
117 decoded_token["jti"] = None
118
119 if csrf_value:
120 if "csrf" not in decoded_token:
121 raise JWTDecodeError("Missing claim: csrf")
122 if not compare_digest(decoded_token["csrf"], csrf_value):
123 raise CSRFError("CSRF double submit tokens do not match")
124
125 return decoded_token