1import json
2from typing import Any
3from typing import Type
4from typing import TYPE_CHECKING
5
6from flask import current_app
7from flask import Flask
8
9from flask_jwt_extended.exceptions import RevokedTokenError
10from flask_jwt_extended.exceptions import UserClaimsVerificationError
11from flask_jwt_extended.exceptions import WrongTokenError
12
13try:
14 from flask.json.provider import DefaultJSONProvider
15
16 HAS_JSON_PROVIDER = True
17except ModuleNotFoundError: # pragma: no cover
18 # The flask.json.provider module was added in Flask 2.2.
19 # Further details are handled in get_json_encoder.
20 HAS_JSON_PROVIDER = False
21
22
23if TYPE_CHECKING: # pragma: no cover
24 from flask_jwt_extended import JWTManager
25
26
def get_jwt_manager() -> "JWTManager":
    """Return the JWTManager stored on the current Flask application.

    The manager is looked up in ``current_app.extensions`` under the
    ``"flask-jwt-extended"`` key.

    :raises RuntimeError: if no manager is registered under that key, i.e. a
        JWTManager was never initialized with this application.
    """
    registered = current_app.extensions
    try:
        manager = registered["flask-jwt-extended"]
    except KeyError:  # pragma: no cover
        message = (
            "You must initialize a JWTManager with this flask "
            "application before using this method"
        )
        # Hide the KeyError: the RuntimeError message already explains it.
        raise RuntimeError(message) from None
    return manager
35
36
def has_user_lookup() -> bool:
    """Report whether a user lookup callback is registered on the JWTManager.

    :return: ``True`` when the manager's ``_user_lookup_callback`` is not
        ``None``, otherwise ``False``.
    """
    callback = get_jwt_manager()._user_lookup_callback
    return callback is not None
40
41
def user_lookup(*args, **kwargs) -> Any:
    """Invoke the JWTManager's user lookup callback, if one is set.

    All positional and keyword arguments are forwarded unchanged to the
    callback. When the callback attribute is falsy (e.g. ``None`` because no
    callback was registered), that falsy value is returned and nothing is
    called.
    """
    lookup = get_jwt_manager()._user_lookup_callback
    if not lookup:
        return lookup
    return lookup(*args, **kwargs)
47
48
def verify_token_type(decoded_token: dict, refresh: bool) -> None:
    """Check that a decoded token's ``type`` claim matches the expected kind.

    :param decoded_token: The decoded token payload. It must contain a
        ``"type"`` key; a missing key propagates as ``KeyError``.
    :param refresh: If truthy, only tokens whose type is ``"refresh"`` are
        accepted. If falsy, any type except ``"refresh"`` is accepted.
    :raises WrongTokenError: when the token's type does not match ``refresh``.
    """
    token_is_refresh = decoded_token["type"] == "refresh"
    if refresh and not token_is_refresh:
        raise WrongTokenError("Only refresh tokens are allowed")
    if token_is_refresh and not refresh:
        raise WrongTokenError("Only non-refresh tokens are allowed")
54
55
def verify_token_not_blocklisted(jwt_header: dict, jwt_data: dict) -> None:
    """Reject the token if the manager's blocklist callback flags it.

    The JWTManager's ``_token_in_blocklist_callback`` is called with the
    token header and payload. A truthy result means the token is revoked.

    :raises RevokedTokenError: when the callback returns a truthy value.
    """
    is_blocklisted = get_jwt_manager()._token_in_blocklist_callback
    if is_blocklisted(jwt_header, jwt_data):
        raise RevokedTokenError(jwt_header, jwt_data)
60
61
def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None:
    """Run the manager's custom token verification callback.

    The JWTManager's ``_token_verification_callback`` is called with the
    token header and payload. A falsy result means verification failed.

    :raises UserClaimsVerificationError: when the callback returns a falsy
        value.
    """
    verify = get_jwt_manager()._token_verification_callback
    if verify(jwt_header, jwt_data):
        return
    raise UserClaimsVerificationError(
        "User claims verification failed", jwt_header, jwt_data
    )
67
68
class JSONEncoder(json.JSONEncoder):
    """A JSON encoder that delegates unknown objects to the app's JSON provider.

    The standard library calls :meth:`default` for any object it cannot
    serialize natively. This encoder forwards such objects to the ``default``
    callable of the current Flask application's JSON provider.
    """

    def default(self, o: Any) -> Any:
        """Serialize ``o`` using the current application's JSON provider.

        The active provider *instance* (``current_app.json``) is consulted
        first. Flask lets an application install a provider by assigning
        ``app.json`` directly, and per-instance overrides of ``default`` live
        on that instance, so reading only ``json_provider_class`` would miss
        them. Reading the attribute from the instance also binds a regular
        (non-static) ``default`` method correctly.

        If the provider defines no ``default`` attribute, the static
        ``DefaultJSONProvider.default`` is used instead.

        Whatever the chosen callable raises propagates unchanged. Per the
        ``json.JSONEncoder.default`` contract this is presumably a
        ``TypeError`` for unserializable objects, but it depends on the provider.
        """
        provider = getattr(current_app, "json", None)
        if provider is None:
            # Defensive: fall back to the configured provider class when the
            # app does not expose a provider instance.
            provider = current_app.json_provider_class
        default = getattr(provider, "default", DefaultJSONProvider.default)
        return default(o)
79
80
def get_json_encoder(app: Flask) -> Type[json.JSONEncoder]:
    """Return the JSON encoder class to use for the given Flask app.

    Starting with Flask 2.2, an application registers its JSON handling
    through a JSON provider (``app.json_provider_class``). That interface is
    not compatible with :class:`json.JSONEncoder`, so on those versions this
    returns this module's ``JSONEncoder`` wrapper, which delegates its
    ``default`` hook to the registered provider.

    Lookup order:
        - ``app.json_encoder`` - for Flask < 2.2
        - the registered JSON provider's ``default``
        - ``flask.json.provider.DefaultJSONProvider.default``

    :param app: The Flask application whose encoder should be used.
    :return: A :class:`json.JSONEncoder` subclass.
    """
    if HAS_JSON_PROVIDER:
        return JSONEncoder
    # Flask < 2.2 has no JSON provider; the app carries its own encoder class.
    return app.json_encoder  # type: ignore # pragma: no cover