1import datetime
2from typing import Any
3from typing import Callable
4from typing import Optional
5
6import jwt
7from flask import Flask
8from jwt import DecodeError
9from jwt import ExpiredSignatureError
10from jwt import InvalidAudienceError
11from jwt import InvalidIssuerError
12from jwt import InvalidTokenError
13from jwt import MissingRequiredClaimError
14
15from flask_jwt_extended.config import config
16from flask_jwt_extended.default_callbacks import default_additional_claims_callback
17from flask_jwt_extended.default_callbacks import default_blocklist_callback
18from flask_jwt_extended.default_callbacks import default_decode_key_callback
19from flask_jwt_extended.default_callbacks import default_encode_key_callback
20from flask_jwt_extended.default_callbacks import default_expired_token_callback
21from flask_jwt_extended.default_callbacks import default_invalid_token_callback
22from flask_jwt_extended.default_callbacks import default_jwt_headers_callback
23from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback
24from flask_jwt_extended.default_callbacks import default_revoked_token_callback
25from flask_jwt_extended.default_callbacks import default_token_verification_callback
26from flask_jwt_extended.default_callbacks import (
27 default_token_verification_failed_callback,
28)
29from flask_jwt_extended.default_callbacks import default_unauthorized_callback
30from flask_jwt_extended.default_callbacks import default_user_identity_callback
31from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback
32from flask_jwt_extended.exceptions import CSRFError
33from flask_jwt_extended.exceptions import FreshTokenRequired
34from flask_jwt_extended.exceptions import InvalidHeaderError
35from flask_jwt_extended.exceptions import InvalidQueryParamError
36from flask_jwt_extended.exceptions import JWTDecodeError
37from flask_jwt_extended.exceptions import NoAuthorizationError
38from flask_jwt_extended.exceptions import RevokedTokenError
39from flask_jwt_extended.exceptions import UserClaimsVerificationError
40from flask_jwt_extended.exceptions import UserLookupError
41from flask_jwt_extended.exceptions import WrongTokenError
42from flask_jwt_extended.tokens import _decode_jwt
43from flask_jwt_extended.tokens import _encode_jwt
44from flask_jwt_extended.typing import ExpiresDelta
45from flask_jwt_extended.typing import Fresh
46from flask_jwt_extended.utils import current_user_context_processor
47
48
class JWTManager:
    """
    An object used to hold JWT settings and callback functions for the
    Flask-JWT-Extended extension.

    Instances of :class:`JWTManager` are *not* bound to specific apps, so
    you can create one in the main body of your code and then bind it
    to your app in a factory function.
    """

    def __init__(
        self, app: Optional[Flask] = None, add_context_processor: bool = False
    ) -> None:
        """
        Create the JWTManager instance. You can either pass a flask application
        in directly here to register this extension with the flask app, or
        call init_app after creating this object (in a factory pattern).

        :param app:
            The Flask Application object
        :param add_context_processor:
            Controls if `current_user` should be added to Flask's template
            context (and thus be available for use in Jinja templates). Defaults
            to ``False``.
        """
        # Install the default callbacks. Each one can be replaced later via the
        # matching ``*_loader`` decorator defined on this class.
        self._decode_key_callback = default_decode_key_callback
        self._encode_key_callback = default_encode_key_callback
        self._expired_token_callback = default_expired_token_callback
        self._invalid_token_callback = default_invalid_token_callback
        self._jwt_additional_header_callback = default_jwt_headers_callback
        self._needs_fresh_token_callback = default_needs_fresh_token_callback
        self._revoked_token_callback = default_revoked_token_callback
        self._token_in_blocklist_callback = default_blocklist_callback
        self._token_verification_callback = default_token_verification_callback
        self._unauthorized_callback = default_unauthorized_callback
        self._user_claims_callback = default_additional_claims_callback
        self._user_identity_callback = default_user_identity_callback
        # No default: stays None until ``user_lookup_loader`` is used.
        self._user_lookup_callback: Optional[Callable] = None
        self._user_lookup_error_callback = default_user_lookup_error_callback
        self._token_verification_failed_callback = (
            default_token_verification_failed_callback
        )

        # Register this extension with the flask app now (if it is provided)
        if app is not None:
            self.init_app(app, add_context_processor)

    def init_app(self, app: Flask, add_context_processor: bool = False) -> None:
        """
        Register this extension with the flask app.

        :param app:
            The Flask Application object
        :param add_context_processor:
            Controls if `current_user` should be added to Flask's template
            context (and thus be available for use in Jinja templates). Defaults
            to ``False``.
        """
        # Save this so we can use it later in the extension
        if not hasattr(app, "extensions"):  # pragma: no cover
            app.extensions = {}
        app.extensions["flask-jwt-extended"] = self

        if add_context_processor:
            app.context_processor(current_user_context_processor)

        # Set all the default configurations for this extension
        self._set_default_configuration_options(app)
        self._set_error_handler_callbacks(app)

    def _set_error_handler_callbacks(self, app: Flask) -> None:
        """
        Register Flask error handlers that turn JWT-related exceptions into
        responses produced by this manager's callbacks.

        Each handler reads the callback attribute from ``self`` when the error
        is handled, not when it is registered. Loaders registered after
        ``init_app`` are therefore still honoured.
        """

        @app.errorhandler(CSRFError)
        def handle_csrf_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(DecodeError)
        def handle_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(ExpiredSignatureError)
        def handle_expired_error(e):
            # jwt_header / jwt_data are attached in _decode_jwt_from_config.
            return self._expired_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(FreshTokenRequired)
        def handle_fresh_token_required(e):
            return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(MissingRequiredClaimError)
        def handle_missing_required_claim_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidAudienceError)
        def handle_invalid_audience_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidIssuerError)
        def handle_invalid_issuer_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidHeaderError)
        def handle_invalid_header_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidTokenError)
        def handle_invalid_token_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(JWTDecodeError)
        def handle_jwt_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(NoAuthorizationError)
        def handle_auth_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(InvalidQueryParamError)
        def handle_invalid_query_param_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(RevokedTokenError)
        def handle_revoked_token_error(e):
            return self._revoked_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(UserClaimsVerificationError)
        def handle_failed_token_verification(e):
            return self._token_verification_failed_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(UserLookupError)
        def handler_user_lookup_error(e):
            return self._user_lookup_error_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(WrongTokenError)
        def handle_wrong_token_error(e):
            return self._invalid_token_callback(str(e))

    @staticmethod
    def _set_default_configuration_options(app: Flask) -> None:
        """
        Populate ``app.config`` with a default for every ``JWT_*`` option.

        ``setdefault`` is used, so any value the application has already
        configured is left untouched.
        """
        app.config.setdefault(
            "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15)
        )
        app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie")
        app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/")
        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token")
        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/")
        app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
        app.config.setdefault("JWT_ALGORITHM", "HS256")
        app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True)
        app.config.setdefault("JWT_COOKIE_DOMAIN", None)
        app.config.setdefault("JWT_COOKIE_SAMESITE", None)
        app.config.setdefault("JWT_COOKIE_SECURE", False)
        app.config.setdefault("JWT_CSRF_CHECK_FORM", False)
        app.config.setdefault("JWT_CSRF_IN_COOKIES", True)
        app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        app.config.setdefault("JWT_DECODE_ALGORITHMS", None)
        app.config.setdefault("JWT_DECODE_AUDIENCE", None)
        app.config.setdefault("JWT_DECODE_ISSUER", None)
        app.config.setdefault("JWT_DECODE_LEEWAY", 0)
        app.config.setdefault("JWT_ENCODE_AUDIENCE", None)
        app.config.setdefault("JWT_ENCODE_ISSUER", None)
        app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg")
        app.config.setdefault("JWT_HEADER_NAME", "Authorization")
        app.config.setdefault("JWT_HEADER_TYPE", "Bearer")
        app.config.setdefault("JWT_IDENTITY_CLAIM", "sub")
        app.config.setdefault("JWT_JSON_KEY", "access_token")
        app.config.setdefault("JWT_PRIVATE_KEY", None)
        app.config.setdefault("JWT_PUBLIC_KEY", None)
        app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt")
        app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "")
        app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie")
        app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/")
        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token")
        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/")
        app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
        app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token")
        app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30))
        app.config.setdefault("JWT_SECRET_KEY", None)
        app.config.setdefault("JWT_SESSION_COOKIE", True)
        app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",))
        app.config.setdefault("JWT_ENCODE_NBF", True)

    def additional_claims_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to add additional claims
        when creating a JWT. The claims returned by this function will be merged
        with any claims passed in via the ``additional_claims`` argument to
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`. Claims passed in
        explicitly take precedence.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return a dictionary of claims to add to the JWT.
        """
        self._user_claims_callback = callback
        return callback

    def additional_headers_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to add additional headers
        when creating a JWT. The headers returned by this function will be merged
        with any headers passed in via the ``additional_headers`` argument to
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`. Headers passed in
        explicitly take precedence.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return a dictionary of headers to add to the JWT.
        """
        self._jwt_additional_header_callback = callback
        return callback

    def decode_key_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for dynamically setting the JWT
        decode key based on the **UNVERIFIED** contents of the token. Think
        carefully before using this functionality, in most cases you probably
        don't need it.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the
        unverified JWT.

        The second argument is a dictionary containing the payload data of the
        unverified JWT.

        The decorated function must return a *string* that is used to decode and
        verify the token.
        """
        self._decode_key_callback = callback
        return callback

    def encode_key_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for dynamically setting the JWT
        encode key based on the token's identity. Think carefully before using
        this functionality, in most cases you probably don't need it.

        The decorated function must take **one** argument.

        The argument is the identity used to create this JWT.

        The decorated function must return a *string* which is the secret key used
        to encode the JWT.
        """
        self._encode_key_callback = callback
        return callback

    def expired_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when an expired JWT is encountered.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._expired_token_callback = callback
        return callback

    def invalid_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when an invalid JWT attempts to access a protected endpoint.

        The decorated function must take **one** argument.

        The argument is a string which contains the reason why a token is invalid.

        The decorated function must return a Flask Response.
        """
        self._invalid_token_callback = callback
        return callback

    def needs_fresh_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when a valid and non-fresh token is used on an endpoint
        that is marked as ``fresh=True``.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._needs_fresh_token_callback = callback
        return callback

    def revoked_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when a revoked token is encountered.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._revoked_token_callback = callback
        return callback

    def token_in_blocklist_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to check if a JWT has
        been revoked.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return ``True`` if the token has been
        revoked, ``False`` otherwise.
        """
        self._token_in_blocklist_callback = callback
        return callback

    def token_verification_failed_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when the claims verification check fails.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._token_verification_failed_callback = callback
        return callback

    def token_verification_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used for custom verification
        of a valid JWT.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return ``True`` if the token is valid, or
        ``False`` otherwise.
        """
        self._token_verification_callback = callback
        return callback

    def unauthorized_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when no JWT is present.

        The decorated function must take **one** argument.

        The argument is a string that explains why the JWT could not be found.

        The decorated function must return a Flask Response.
        """
        self._unauthorized_callback = callback
        return callback

    def user_identity_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to convert an identity to
        a JSON serializable format when creating JWTs. This is useful for
        using objects (such as SQLAlchemy instances) as the identity when
        creating your tokens.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return JSON serializable data.
        """
        self._user_identity_callback = callback
        return callback

    def user_lookup_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to convert a JWT into
        a python object that can be used in a protected endpoint. This is useful
        for automatically loading a SQLAlchemy instance based on the contents
        of the JWT.

        The object returned from this function can be accessed via
        :attr:`~flask_jwt_extended.current_user` or
        :meth:`~flask_jwt_extended.get_current_user`

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function can return any python object, which can then be
        accessed in a protected endpoint. If an object cannot be loaded, for
        example if a user has been deleted from your database, ``None`` must be
        returned to indicate that an error occurred loading the user.
        """
        self._user_lookup_callback = callback
        return callback

    def user_lookup_error_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when loading a user via
        :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._user_lookup_error_callback = callback
        return callback

    @staticmethod
    def _merge_overrides(base: Optional[dict], overrides: Optional[dict]) -> Optional[dict]:
        """
        Combine a callback-supplied mapping with caller-supplied overrides.

        If ``overrides`` is ``None``, ``base`` is returned unchanged. Otherwise
        a *new* dict is built from ``base`` (``None`` is treated as empty) and
        updated with ``overrides``, so the overrides win on key conflicts.
        ``base`` itself is never mutated.
        """
        if overrides is None:
            return base
        merged = dict(base or {})
        merged.update(overrides)
        return merged

    def _encode_jwt_from_config(
        self,
        identity: Any,
        token_type: str,
        claims: Optional[dict] = None,
        fresh: Fresh = False,
        expires_delta: Optional[ExpiresDelta] = None,
        headers: Optional[dict] = None,
    ) -> str:
        """
        Build and sign a JWT using the current app configuration and the
        registered callbacks.

        :param identity:
            The identity passed to the additional-claims, additional-headers,
            user-identity and encode-key callbacks.
        :param token_type:
            ``"access"`` selects ``config.access_expires`` as the default
            lifetime; any other value selects ``config.refresh_expires``.
        :param claims:
            Extra claims. They take precedence over claims returned by the
            additional-claims callback.
        :param fresh:
            Forwarded to ``_encode_jwt``.
        :param expires_delta:
            Token lifetime. When ``None``, the configured default for
            ``token_type`` is used.
        :param headers:
            Extra headers. They take precedence over headers returned by the
            additional-headers callback.
        :return: The encoded token.
        """
        # Merge into fresh dicts rather than calling .update() on what the
        # callbacks return. A callback may hand back a shared object (e.g. a
        # module-level dict), and mutating it would leak one token's per-call
        # claims/headers into every token issued afterwards.
        header_overrides = self._merge_overrides(
            self._jwt_additional_header_callback(identity), headers
        )
        claim_overrides = self._merge_overrides(
            self._user_claims_callback(identity), claims
        )

        if expires_delta is None:
            if token_type == "access":
                expires_delta = config.access_expires
            else:
                expires_delta = config.refresh_expires

        return _encode_jwt(
            algorithm=config.algorithm,
            audience=config.encode_audience,
            claim_overrides=claim_overrides,
            csrf=config.cookie_csrf_protect,
            expires_delta=expires_delta,
            fresh=fresh,
            header_overrides=header_overrides,
            identity=self._user_identity_callback(identity),
            identity_claim_key=config.identity_claim_key,
            issuer=config.encode_issuer,
            json_encoder=config.json_encoder,
            secret=self._encode_key_callback(identity),
            token_type=token_type,
            nbf=config.encode_nbf,
        )

    def _decode_jwt_from_config(
        self,
        encoded_token: str,
        csrf_value: Optional[str] = None,
        allow_expired: bool = False,
    ) -> dict:
        """
        Decode ``encoded_token`` using the current app configuration.

        The token is first read *without* signature verification. Its
        unverified header and payload go to the decode-key callback, which
        returns the secret used for the real decode done by ``_decode_jwt``.

        :param encoded_token: The encoded JWT string.
        :param csrf_value: Forwarded to ``_decode_jwt``.
        :param allow_expired: Forwarded to ``_decode_jwt``.
        :return: The decoded token payload.
        :raises ExpiredSignatureError:
            Re-raised from ``_decode_jwt`` (presumably only when
            ``allow_expired`` is false). Before re-raising, ``jwt_header`` and
            ``jwt_data`` are attached to the exception so the expired-token
            error handler can pass them to its callback.
        """
        # Unverified peek: only used to choose the decode key below.
        unverified_claims = jwt.decode(
            encoded_token,
            algorithms=config.decode_algorithms,
            options={"verify_signature": False},
        )
        unverified_headers = jwt.get_unverified_header(encoded_token)
        secret = self._decode_key_callback(unverified_headers, unverified_claims)

        kwargs = {
            "algorithms": config.decode_algorithms,
            "audience": config.decode_audience,
            "csrf_value": csrf_value,
            "encoded_token": encoded_token,
            "identity_claim_key": config.identity_claim_key,
            "issuer": config.decode_issuer,
            "leeway": config.leeway,
            "secret": secret,
            "verify_aud": config.decode_audience is not None,
        }

        try:
            return _decode_jwt(**kwargs, allow_expired=allow_expired)
        except ExpiredSignatureError as e:
            # TODO: If we ever do another breaking change, don't raise this pyjwt
            #       error directly, instead raise a custom error of ours from this
            #       error.
            e.jwt_header = unverified_headers  # type: ignore
            e.jwt_data = _decode_jwt(**kwargs, allow_expired=True)  # type: ignore
            raise