1import datetime
2from typing import Any
3from typing import Callable
4from typing import Optional
5
6import jwt
7from flask import Flask
8from jwt import DecodeError
9from jwt import ExpiredSignatureError
10from jwt import InvalidAudienceError
11from jwt import InvalidIssuerError
12from jwt import InvalidTokenError
13from jwt import MissingRequiredClaimError
14
15from flask_jwt_extended.config import config
16from flask_jwt_extended.default_callbacks import default_additional_claims_callback
17from flask_jwt_extended.default_callbacks import default_blocklist_callback
18from flask_jwt_extended.default_callbacks import default_decode_key_callback
19from flask_jwt_extended.default_callbacks import default_encode_key_callback
20from flask_jwt_extended.default_callbacks import default_expired_token_callback
21from flask_jwt_extended.default_callbacks import default_invalid_token_callback
22from flask_jwt_extended.default_callbacks import default_jwt_headers_callback
23from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback
24from flask_jwt_extended.default_callbacks import default_revoked_token_callback
25from flask_jwt_extended.default_callbacks import default_token_verification_callback
26from flask_jwt_extended.default_callbacks import (
27 default_token_verification_failed_callback,
28)
29from flask_jwt_extended.default_callbacks import default_unauthorized_callback
30from flask_jwt_extended.default_callbacks import default_user_identity_callback
31from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback
32from flask_jwt_extended.exceptions import CSRFError
33from flask_jwt_extended.exceptions import FreshTokenRequired
34from flask_jwt_extended.exceptions import InvalidHeaderError
35from flask_jwt_extended.exceptions import InvalidQueryParamError
36from flask_jwt_extended.exceptions import JWTDecodeError
37from flask_jwt_extended.exceptions import NoAuthorizationError
38from flask_jwt_extended.exceptions import RevokedTokenError
39from flask_jwt_extended.exceptions import UserClaimsVerificationError
40from flask_jwt_extended.exceptions import UserLookupError
41from flask_jwt_extended.exceptions import WrongTokenError
42from flask_jwt_extended.tokens import _decode_jwt
43from flask_jwt_extended.tokens import _encode_jwt
44from flask_jwt_extended.typing import ExpiresDelta
45from flask_jwt_extended.typing import Fresh
46from flask_jwt_extended.utils import current_user_context_processor
47
48
class JWTManager(object):
    """
    An object used to hold JWT settings and callback functions for the
    Flask-JWT-Extended extension.

    Instances of :class:`JWTManager` are *not* bound to specific apps, so
    you can create one in the main body of your code and then bind it
    to your app in a factory function.
    """

    def __init__(
        self, app: Optional[Flask] = None, add_context_processor: bool = False
    ) -> None:
        """
        Create the JWTManager instance. You can either pass a Flask application
        in directly here to register this extension with that app, or call
        :meth:`init_app` after creating this object (in a factory pattern).

        :param app:
            The Flask Application object
        :param add_context_processor:
            Controls whether `current_user` should be added to Flask's template
            context (and thus be available for use in Jinja templates). Defaults
            to ``False``.
        """
        # Register the default callbacks. Each one can be overridden with the
        # matching ``*_loader`` decorator defined on this class.
        self._decode_key_callback = default_decode_key_callback
        self._encode_key_callback = default_encode_key_callback
        self._expired_token_callback = default_expired_token_callback
        self._invalid_token_callback = default_invalid_token_callback
        self._jwt_additional_header_callback = default_jwt_headers_callback
        self._needs_fresh_token_callback = default_needs_fresh_token_callback
        self._revoked_token_callback = default_revoked_token_callback
        self._token_in_blocklist_callback = default_blocklist_callback
        self._token_verification_callback = default_token_verification_callback
        self._unauthorized_callback = default_unauthorized_callback
        self._user_claims_callback = default_additional_claims_callback
        self._user_identity_callback = default_user_identity_callback
        # No default user lookup: stays None until user_lookup_loader sets one.
        self._user_lookup_callback: Optional[Callable] = None
        self._user_lookup_error_callback = default_user_lookup_error_callback
        self._token_verification_failed_callback = (
            default_token_verification_failed_callback
        )

        # Register this extension with the flask app now (if it is provided)
        if app is not None:
            self.init_app(app, add_context_processor)

    def init_app(self, app: Flask, add_context_processor: bool = False) -> None:
        """
        Register this extension with the flask app.

        This stores the manager in ``app.extensions["flask-jwt-extended"]``,
        optionally registers the `current_user` template context processor,
        fills in default ``JWT_*`` configuration values (values already present
        in ``app.config`` are kept), and registers the JWT error handlers.

        :param app:
            The Flask Application object
        :param add_context_processor:
            Controls whether `current_user` should be added to Flask's template
            context (and thus be available for use in Jinja templates). Defaults
            to ``False``.
        """
        # Save this so we can use it later in the extension
        if not hasattr(app, "extensions"):  # pragma: no cover
            app.extensions = {}
        app.extensions["flask-jwt-extended"] = self

        if add_context_processor:
            app.context_processor(current_user_context_processor)

        # Set all the default configurations for this extension
        self._set_default_configuration_options(app)
        self._set_error_handler_callbacks(app)

    def _set_error_handler_callbacks(self, app: Flask) -> None:
        """
        Register Flask error handlers that turn JWT-related exceptions into
        responses produced by this manager's callbacks.

        Each handler reads ``self._*_callback`` when the error is handled, not
        when it is registered. Loaders applied after :meth:`init_app` therefore
        still take effect.
        """

        @app.errorhandler(CSRFError)
        def handle_csrf_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(DecodeError)
        def handle_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(ExpiredSignatureError)
        def handle_expired_error(e):
            # jwt_header / jwt_data are attached in _decode_jwt_from_config.
            return self._expired_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(FreshTokenRequired)
        def handle_fresh_token_required(e):
            return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(MissingRequiredClaimError)
        def handle_missing_required_claim_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidAudienceError)
        def handle_invalid_audience_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidIssuerError)
        def handle_invalid_issuer_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidHeaderError)
        def handle_invalid_header_error(e):
            return self._invalid_token_callback(str(e))

        # PyJWT's InvalidTokenError is the base class of the more specific PyJWT
        # errors handled above. Flask dispatches to the most specific registered
        # handler, so this one only catches the remaining PyJWT validation errors.
        @app.errorhandler(InvalidTokenError)
        def handle_invalid_token_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(JWTDecodeError)
        def handle_jwt_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(NoAuthorizationError)
        def handle_auth_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(InvalidQueryParamError)
        def handle_invalid_query_param_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(RevokedTokenError)
        def handle_revoked_token_error(e):
            return self._revoked_token_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(UserClaimsVerificationError)
        def handle_failed_token_verification(e):
            return self._token_verification_failed_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(UserLookupError)
        def handle_user_lookup_error(e):
            return self._user_lookup_error_callback(e.jwt_header, e.jwt_data)

        @app.errorhandler(WrongTokenError)
        def handle_wrong_token_error(e):
            return self._invalid_token_callback(str(e))

    @staticmethod
    def _set_default_configuration_options(app: Flask) -> None:
        """
        Populate ``app.config`` with the extension's default ``JWT_*`` options.

        ``setdefault`` is used throughout, so any value the application has
        already configured takes precedence over these defaults.
        """
        app.config.setdefault(
            "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15)
        )
        app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie")
        app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/")
        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token")
        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/")
        app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
        app.config.setdefault("JWT_ALGORITHM", "HS256")
        app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True)
        app.config.setdefault("JWT_COOKIE_DOMAIN", None)
        app.config.setdefault("JWT_COOKIE_SAMESITE", None)
        app.config.setdefault("JWT_COOKIE_SECURE", False)
        app.config.setdefault("JWT_CSRF_CHECK_FORM", False)
        app.config.setdefault("JWT_CSRF_IN_COOKIES", True)
        app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        app.config.setdefault("JWT_DECODE_ALGORITHMS", None)
        app.config.setdefault("JWT_DECODE_AUDIENCE", None)
        app.config.setdefault("JWT_DECODE_ISSUER", None)
        app.config.setdefault("JWT_DECODE_LEEWAY", 0)
        app.config.setdefault("JWT_ENCODE_AUDIENCE", None)
        app.config.setdefault("JWT_ENCODE_ISSUER", None)
        app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg")
        app.config.setdefault("JWT_HEADER_NAME", "Authorization")
        app.config.setdefault("JWT_HEADER_TYPE", "Bearer")
        app.config.setdefault("JWT_IDENTITY_CLAIM", "sub")
        app.config.setdefault("JWT_JSON_KEY", "access_token")
        app.config.setdefault("JWT_PRIVATE_KEY", None)
        app.config.setdefault("JWT_PUBLIC_KEY", None)
        app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt")
        app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "")
        app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie")
        app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/")
        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token")
        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/")
        app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
        app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token")
        app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30))
        app.config.setdefault("JWT_SECRET_KEY", None)
        app.config.setdefault("JWT_SESSION_COOKIE", True)
        app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",))
        app.config.setdefault("JWT_VERIFY_SUB", True)
        app.config.setdefault("JWT_ENCODE_NBF", True)

    def additional_claims_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to add additional claims
        when creating a JWT. The claims returned by this function will be merged
        with any claims passed in via the ``additional_claims`` argument to
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return a dictionary of claims to add to the JWT.
        """
        self._user_claims_callback = callback
        return callback

    def additional_headers_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to add additional headers
        when creating a JWT. The headers returned by this function will be merged
        with any headers passed in via the ``additional_headers`` argument to
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return a dictionary of headers to add to the JWT.
        """
        self._jwt_additional_header_callback = callback
        return callback

    def decode_key_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for dynamically setting the JWT
        decode key based on the **UNVERIFIED** contents of the token. Think
        carefully before using this functionality; in most cases you probably
        don't need it.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the
        unverified JWT.

        The second argument is a dictionary containing the payload data of the
        unverified JWT.

        The decorated function must return a *string* that is used to decode and
        verify the token.
        """
        self._decode_key_callback = callback
        return callback

    def encode_key_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for dynamically setting the JWT
        encode key based on the token's identity. Think carefully before using
        this functionality; in most cases you probably don't need it.

        The decorated function must take **one** argument.

        The argument is the identity used to create this JWT.

        The decorated function must return a *string* which is the secret key used
        to encode the JWT.
        """
        self._encode_key_callback = callback
        return callback

    def expired_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when an expired JWT is encountered.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._expired_token_callback = callback
        return callback

    def invalid_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when an invalid JWT attempts to access a protected endpoint.

        The decorated function must take **one** argument.

        The argument is a string which contains the reason why a token is invalid.

        The decorated function must return a Flask Response.
        """
        self._invalid_token_callback = callback
        return callback

    def needs_fresh_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when a valid and non-fresh token is used on an endpoint
        that is marked as ``fresh=True``.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._needs_fresh_token_callback = callback
        return callback

    def revoked_token_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function for returning a custom
        response when a revoked token is encountered.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._revoked_token_callback = callback
        return callback

    def token_in_blocklist_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to check if a JWT has
        been revoked.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return ``True`` if the token has been
        revoked, ``False`` otherwise.
        """
        self._token_in_blocklist_callback = callback
        return callback

    def token_verification_failed_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when the claims verification check fails.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._token_verification_failed_callback = callback
        return callback

    def token_verification_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used for custom verification
        of a valid JWT.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return ``True`` if the token is valid, or
        ``False`` otherwise.
        """
        self._token_verification_callback = callback
        return callback

    def unauthorized_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when no JWT is present.

        The decorated function must take **one** argument.

        The argument is a string that explains why the JWT could not be found.

        The decorated function must return a Flask Response.
        """
        self._unauthorized_callback = callback
        return callback

    def user_identity_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to convert an identity to
        a string when creating JWTs. This is useful for using objects (such as
        SQLAlchemy instances) as the identity when creating your tokens.

        The decorated function must take **one** argument.

        The argument is the identity that was used when creating a JWT.

        The decorated function must return a string.
        """
        self._user_identity_callback = callback
        return callback

    def user_lookup_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to convert a JWT into
        a python object that can be used in a protected endpoint. This is useful
        for automatically loading a SQLAlchemy instance based on the contents
        of the JWT.

        The object returned from this function can be accessed via
        :attr:`~flask_jwt_extended.current_user` or
        :meth:`~flask_jwt_extended.get_current_user`

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function can return any python object, which can then be
        accessed in a protected endpoint. If an object cannot be loaded, for
        example if a user has been deleted from your database, ``None`` must be
        returned to indicate that an error occurred loading the user.
        """
        self._user_lookup_callback = callback
        return callback

    def user_lookup_error_loader(self, callback: Callable) -> Callable:
        """
        This decorator sets the callback function used to return a custom
        response when loading a user via
        :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails.

        The decorated function must take **two** arguments.

        The first argument is a dictionary containing the header data of the JWT.

        The second argument is a dictionary containing the payload data of the JWT.

        The decorated function must return a Flask Response.
        """
        self._user_lookup_error_callback = callback
        return callback

    def _encode_jwt_from_config(
        self,
        identity: Any,
        token_type: str,
        claims: Optional[dict] = None,
        fresh: Fresh = False,
        expires_delta: Optional[ExpiresDelta] = None,
        headers: Optional[dict] = None,
    ) -> str:
        """
        Build and sign a JWT using the current app config and this manager's
        callbacks.

        :param identity:
            The identity the token is created for. It is passed to the
            additional-headers, additional-claims, user-identity and encode-key
            callbacks.
        :param token_type:
            ``"access"`` selects ``config.access_expires`` as the default
            lifetime; any other value selects ``config.refresh_expires``.
        :param claims:
            Extra claims merged over those returned by the additional claims
            callback (these win on key conflicts).
        :param fresh:
            Forwarded unchanged to ``_encode_jwt``.
        :param expires_delta:
            Overrides the config-derived lifetime when not ``None``.
        :param headers:
            Extra headers merged over those returned by the additional headers
            callback (these win on key conflicts).
        :return:
            The encoded token.
        """
        # Merge into a *new* dict rather than calling .update() on the value
        # the callback returned. A user callback that returns a shared dict
        # (e.g. a module-level constant) would otherwise accumulate per-call
        # headers/claims and leak them into every later token.
        header_overrides = self._jwt_additional_header_callback(identity)
        if headers is not None:
            header_overrides = {**header_overrides, **headers}

        claim_overrides = self._user_claims_callback(identity)
        if claims is not None:
            claim_overrides = {**claim_overrides, **claims}

        if expires_delta is None:
            if token_type == "access":
                expires_delta = config.access_expires
            else:
                expires_delta = config.refresh_expires

        return _encode_jwt(
            algorithm=config.algorithm,
            audience=config.encode_audience,
            claim_overrides=claim_overrides,
            csrf=config.cookie_csrf_protect,
            expires_delta=expires_delta,
            fresh=fresh,
            header_overrides=header_overrides,
            identity=self._user_identity_callback(identity),
            identity_claim_key=config.identity_claim_key,
            issuer=config.encode_issuer,
            json_encoder=config.json_encoder,
            secret=self._encode_key_callback(identity),
            token_type=token_type,
            nbf=config.encode_nbf,
        )

    def _decode_jwt_from_config(
        self,
        encoded_token: str,
        csrf_value: Optional[str] = None,
        allow_expired: bool = False,
    ) -> dict:
        """
        Decode and verify ``encoded_token`` using the current app config.

        The token is first decoded *without* signature verification so its
        unverified header and payload can be handed to the decode key callback.
        The key that callback returns is then used for the real, verified decode.

        :param encoded_token:
            The raw encoded JWT.
        :param csrf_value:
            Forwarded to ``_decode_jwt`` (presumably compared against the
            token's CSRF claim; confirm in ``tokens._decode_jwt``).
        :param allow_expired:
            Forwarded to ``_decode_jwt``.
        :return:
            The decoded payload returned by ``_decode_jwt``.
        :raises ExpiredSignatureError:
            When the token has expired and ``allow_expired`` is false. Before
            re-raising, ``jwt_header`` and ``jwt_data`` attributes are attached
            so the expired-token error handler can pass them to its callback.
            Other decode/verification errors propagate unchanged.
        """
        unverified_claims = jwt.decode(
            encoded_token,
            algorithms=config.decode_algorithms,
            options={"verify_signature": False},
        )
        unverified_headers = jwt.get_unverified_header(encoded_token)
        secret = self._decode_key_callback(unverified_headers, unverified_claims)

        kwargs = {
            "algorithms": config.decode_algorithms,
            "audience": config.decode_audience,
            "csrf_value": csrf_value,
            "encoded_token": encoded_token,
            "identity_claim_key": config.identity_claim_key,
            "issuer": config.decode_issuer,
            "leeway": config.leeway,
            "secret": secret,
            "verify_aud": config.decode_audience is not None,
            "verify_sub": config.verify_sub,
        }

        try:
            return _decode_jwt(**kwargs, allow_expired=allow_expired)
        except ExpiredSignatureError as e:
            # TODO: If we ever do another breaking change, don't raise this pyjwt
            # error directly, instead raise a custom error of ours from this
            # error.
            e.jwt_header = unverified_headers  # type: ignore
            # Re-decode with expiry ignored so the handler receives a payload
            # that was still signature-verified with the same key.
            e.jwt_data = _decode_jwt(**kwargs, allow_expired=True)  # type: ignore
            raise