1import uuid
2from datetime import datetime
3from datetime import timedelta
4from datetime import timezone
5from hmac import compare_digest
6from json import JSONEncoder
7from typing import Any
8from typing import Iterable
9from typing import List
10from typing import Type
11from typing import Union
12
13import jwt
14from jwt.types import Options
15
16from flask_jwt_extended.exceptions import CSRFError
17from flask_jwt_extended.exceptions import JWTDecodeError
18from flask_jwt_extended.typing import ExpiresDelta
19from flask_jwt_extended.typing import Fresh
20
21
def _encode_jwt(
    algorithm: str,
    audience: Union[str, Iterable[str]],
    claim_overrides: dict,
    csrf: bool,
    expires_delta: ExpiresDelta,
    fresh: Fresh,
    header_overrides: dict,
    identity: Any,
    identity_claim_key: str,
    issuer: str,
    json_encoder: Type[JSONEncoder],
    secret: str,
    token_type: str,
    nbf: bool,
) -> str:
    """Build the claim set for a new token and sign it with ``jwt.encode``.

    :param algorithm: Signing algorithm passed through to ``jwt.encode``.
    :param audience: Value for the ``aud`` claim; omitted when falsy.
    :param claim_overrides: Extra claims merged in last, so they replace any
        claim generated here (including ``iat``, ``jti`` and ``exp``).
    :param csrf: When true, a random UUID4 string is stored in ``csrf``.
    :param expires_delta: Lifetime added to the issue time to produce ``exp``.
        ``False`` (or ``None``) means the token carries no ``exp`` claim.
    :param fresh: Stored as-is in ``fresh``, except that a ``timedelta`` is
        converted to the POSIX timestamp at which freshness ends.
    :param header_overrides: Extra header fields passed to ``jwt.encode``.
    :param identity: Value stored under ``identity_claim_key``.
    :param identity_claim_key: Claim name that holds ``identity``.
    :param issuer: Value for the ``iss`` claim; omitted when falsy.
    :param json_encoder: Encoder class handed to ``jwt.encode`` for the payload.
    :param secret: Key used to sign the token.
    :param token_type: Value of the ``type`` claim.
    :param nbf: When true, ``nbf`` is set to the issue time.
    :return: The encoded token produced by ``jwt.encode``.
    """
    now = datetime.now(timezone.utc)

    if isinstance(fresh, timedelta):
        fresh = datetime.timestamp(now + fresh)

    token_data = {
        "fresh": fresh,
        "iat": now,
        "jti": str(uuid.uuid4()),
        "type": token_type,
        identity_claim_key: identity,
    }

    if nbf:
        token_data["nbf"] = now

    if csrf:
        token_data["csrf"] = str(uuid.uuid4())

    if audience:
        token_data["aud"] = audience

    if issuer:
        token_data["iss"] = issuer

    # ``False`` is the "never expires" sentinel, so check for it explicitly.
    # A zero-length delta such as ``timedelta(0)`` is falsy. A plain
    # truthiness test would silently turn it into a token with no expiry,
    # when it should be expired immediately.
    if expires_delta is not False and expires_delta is not None:
        token_data["exp"] = now + expires_delta

    if claim_overrides:
        token_data.update(claim_overrides)

    return jwt.encode(
        token_data,
        secret,
        algorithm,
        json_encoder=json_encoder,  # type: ignore
        headers=header_overrides,
    )
76
77
def _decode_jwt(
    algorithms: List,
    allow_expired: bool,
    audience: Union[str, Iterable[str]],
    csrf_value: str,
    encoded_token: str,
    identity_claim_key: str,
    issuer: str,
    leeway: int,
    secret: str,
    verify_aud: bool,
    verify_sub: bool,
) -> dict:
    """Verify and decode a token, then fill in defaults for optional claims.

    :param algorithms: Algorithms accepted by ``jwt.decode``.
    :param allow_expired: When true, expiry verification is disabled.
    :param audience: Expected audience passed to ``jwt.decode``.
    :param csrf_value: If non-empty, must match the token's ``csrf`` claim.
    :param encoded_token: The token to decode.
    :param identity_claim_key: Claim that must be present in the token.
    :param issuer: Expected issuer passed to ``jwt.decode``.
    :param leeway: Clock-skew allowance passed to ``jwt.decode``.
    :param secret: Key used to verify the signature.
    :param verify_aud: Whether the ``aud`` claim is verified.
    :param verify_sub: Whether the ``sub`` claim is verified.
    :return: The decoded claims, with ``type`` defaulted to ``"access"``,
        ``fresh`` to ``False`` and ``jti`` to ``None`` when absent.
    :raises JWTDecodeError: If ``identity_claim_key`` is missing, or if a CSRF
        value was supplied but the token has no ``csrf`` claim.
    :raises CSRFError: If the token's ``csrf`` claim does not match
        ``csrf_value``.

    Exceptions raised by ``jwt.decode`` itself propagate unchanged.
    """
    options: Options
    options = {"verify_aud": verify_aud, "verify_sub": verify_sub}
    if allow_expired:
        options["verify_exp"] = False

    # jwt.decode verifies the signature and the exp, iat and nbf claims.
    # exp is skipped when allow_expired turned verify_exp off above, and the
    # aud check follows verify_aud.
    decoded_token = jwt.decode(
        encoded_token,
        secret,
        algorithms=algorithms,
        audience=audience,
        issuer=issuer,
        leeway=leeway,
        options=options,
    )

    # Make sure that any custom claims we expect in the token are present
    if identity_claim_key not in decoded_token:
        raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))

    if "type" not in decoded_token:
        decoded_token["type"] = "access"

    if "fresh" not in decoded_token:
        decoded_token["fresh"] = False

    if "jti" not in decoded_token:
        decoded_token["jti"] = None

    if csrf_value:
        if "csrf" not in decoded_token:
            raise JWTDecodeError("Missing claim: csrf")
        token_csrf = decoded_token["csrf"]
        # csrf_value is presumably the double-submit value taken from the
        # incoming request, so it may contain arbitrary characters.
        # hmac.compare_digest raises TypeError for str arguments containing
        # non-ASCII characters, and for non-str/bytes values. That would
        # escape as an unexpected error instead of a CSRF failure. Compare
        # UTF-8 bytes instead, and treat a non-string claim as a mismatch.
        if not isinstance(token_csrf, str) or not compare_digest(
            token_csrf.encode("utf-8"), csrf_value.encode("utf-8")
        ):
            raise CSRFError("CSRF double submit tokens do not match")

    return decoded_token