1import datetime 
    2from typing import Any 
    3from typing import Callable 
    4from typing import Optional 
    5 
    6import jwt 
    7from flask import Flask 
    8from jwt import DecodeError 
    9from jwt import ExpiredSignatureError 
    10from jwt import InvalidAudienceError 
    11from jwt import InvalidIssuerError 
    12from jwt import InvalidTokenError 
    13from jwt import MissingRequiredClaimError 
    14 
    15from flask_jwt_extended.config import config 
    16from flask_jwt_extended.default_callbacks import default_additional_claims_callback 
    17from flask_jwt_extended.default_callbacks import default_blocklist_callback 
    18from flask_jwt_extended.default_callbacks import default_decode_key_callback 
    19from flask_jwt_extended.default_callbacks import default_encode_key_callback 
    20from flask_jwt_extended.default_callbacks import default_expired_token_callback 
    21from flask_jwt_extended.default_callbacks import default_invalid_token_callback 
    22from flask_jwt_extended.default_callbacks import default_jwt_headers_callback 
    23from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback 
    24from flask_jwt_extended.default_callbacks import default_revoked_token_callback 
    25from flask_jwt_extended.default_callbacks import default_token_verification_callback 
    26from flask_jwt_extended.default_callbacks import ( 
    27    default_token_verification_failed_callback, 
    28) 
    29from flask_jwt_extended.default_callbacks import default_unauthorized_callback 
    30from flask_jwt_extended.default_callbacks import default_user_identity_callback 
    31from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback 
    32from flask_jwt_extended.exceptions import CSRFError 
    33from flask_jwt_extended.exceptions import FreshTokenRequired 
    34from flask_jwt_extended.exceptions import InvalidHeaderError 
    35from flask_jwt_extended.exceptions import InvalidQueryParamError 
    36from flask_jwt_extended.exceptions import JWTDecodeError 
    37from flask_jwt_extended.exceptions import NoAuthorizationError 
    38from flask_jwt_extended.exceptions import RevokedTokenError 
    39from flask_jwt_extended.exceptions import UserClaimsVerificationError 
    40from flask_jwt_extended.exceptions import UserLookupError 
    41from flask_jwt_extended.exceptions import WrongTokenError 
    42from flask_jwt_extended.tokens import _decode_jwt 
    43from flask_jwt_extended.tokens import _encode_jwt 
    44from flask_jwt_extended.typing import ExpiresDelta 
    45from flask_jwt_extended.typing import Fresh 
    46from flask_jwt_extended.utils import current_user_context_processor 
    47 
    48 
    49class JWTManager(object): 
    50    """ 
    51    An object used to hold JWT settings and callback functions for the 
    52    Flask-JWT-Extended extension. 
    53 
    54    Instances of :class:`JWTManager` are *not* bound to specific apps, so 
    55    you can create one in the main body of your code and then bind it 
    56    to your app in a factory function. 
    57    """ 
    58 
    59    def __init__( 
    60        self, app: Optional[Flask] = None, add_context_processor: bool = False 
    61    ) -> None: 
    62        """ 
    63        Create the JWTManager instance. You can either pass a flask application 
    64        in directly here to register this extension with the flask app, or 
    65        call init_app after creating this object (in a factory pattern). 
    66 
    67        :param app: 
    68            The Flask Application object 
    69        :param add_context_processor: 
    70            Controls if `current_user` is should be added to flasks template 
    71            context (and thus be available for use in Jinja templates). Defaults 
    72            to ``False``. 
    73        """ 
    74        # Register the default error handler callback methods. These can be 
    75        # overridden with the appropriate loader decorators 
    76        self._decode_key_callback = default_decode_key_callback 
    77        self._encode_key_callback = default_encode_key_callback 
    78        self._expired_token_callback = default_expired_token_callback 
    79        self._invalid_token_callback = default_invalid_token_callback 
    80        self._jwt_additional_header_callback = default_jwt_headers_callback 
    81        self._needs_fresh_token_callback = default_needs_fresh_token_callback 
    82        self._revoked_token_callback = default_revoked_token_callback 
    83        self._token_in_blocklist_callback = default_blocklist_callback 
    84        self._token_verification_callback = default_token_verification_callback 
    85        self._unauthorized_callback = default_unauthorized_callback 
    86        self._user_claims_callback = default_additional_claims_callback 
    87        self._user_identity_callback = default_user_identity_callback 
    88        self._user_lookup_callback: Optional[Callable] = None 
    89        self._user_lookup_error_callback = default_user_lookup_error_callback 
    90        self._token_verification_failed_callback = ( 
    91            default_token_verification_failed_callback 
    92        ) 
    93 
    94        # Register this extension with the flask app now (if it is provided) 
    95        if app is not None: 
    96            self.init_app(app, add_context_processor) 
    97 
    98    def init_app(self, app: Flask, add_context_processor: bool = False) -> None: 
    99        """ 
    100        Register this extension with the flask app. 
    101 
    102        :param app: 
    103            The Flask Application object 
    104        :param add_context_processor: 
    105            Controls if `current_user` is should be added to flasks template 
    106            context (and thus be available for use in Jinja templates). Defaults 
    107            to ``False``. 
    108        """ 
    109        # Save this so we can use it later in the extension 
    110        if not hasattr(app, "extensions"):  # pragma: no cover 
    111            app.extensions = {} 
    112        app.extensions["flask-jwt-extended"] = self 
    113 
    114        if add_context_processor: 
    115            app.context_processor(current_user_context_processor) 
    116 
    117        # Set all the default configurations for this extension 
    118        self._set_default_configuration_options(app) 
    119        self._set_error_handler_callbacks(app) 
    120 
    121    def _set_error_handler_callbacks(self, app: Flask) -> None: 
    122        @app.errorhandler(CSRFError) 
    123        def handle_csrf_error(e): 
    124            return self._unauthorized_callback(str(e)) 
    125 
    126        @app.errorhandler(DecodeError) 
    127        def handle_decode_error(e): 
    128            return self._invalid_token_callback(str(e)) 
    129 
    130        @app.errorhandler(ExpiredSignatureError) 
    131        def handle_expired_error(e): 
    132            return self._expired_token_callback(e.jwt_header, e.jwt_data) 
    133 
    134        @app.errorhandler(FreshTokenRequired) 
    135        def handle_fresh_token_required(e): 
    136            return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data) 
    137 
    138        @app.errorhandler(MissingRequiredClaimError) 
    139        def handle_missing_required_claim_error(e): 
    140            return self._invalid_token_callback(str(e)) 
    141 
    142        @app.errorhandler(InvalidAudienceError) 
    143        def handle_invalid_audience_error(e): 
    144            return self._invalid_token_callback(str(e)) 
    145 
    146        @app.errorhandler(InvalidIssuerError) 
    147        def handle_invalid_issuer_error(e): 
    148            return self._invalid_token_callback(str(e)) 
    149 
    150        @app.errorhandler(InvalidHeaderError) 
    151        def handle_invalid_header_error(e): 
    152            return self._invalid_token_callback(str(e)) 
    153 
    154        @app.errorhandler(InvalidTokenError) 
    155        def handle_invalid_token_error(e): 
    156            return self._invalid_token_callback(str(e)) 
    157 
    158        @app.errorhandler(JWTDecodeError) 
    159        def handle_jwt_decode_error(e): 
    160            return self._invalid_token_callback(str(e)) 
    161 
    162        @app.errorhandler(NoAuthorizationError) 
    163        def handle_auth_error(e): 
    164            return self._unauthorized_callback(str(e)) 
    165 
    166        @app.errorhandler(InvalidQueryParamError) 
    167        def handle_invalid_query_param_error(e): 
    168            return self._invalid_token_callback(str(e)) 
    169 
    170        @app.errorhandler(RevokedTokenError) 
    171        def handle_revoked_token_error(e): 
    172            return self._revoked_token_callback(e.jwt_header, e.jwt_data) 
    173 
    174        @app.errorhandler(UserClaimsVerificationError) 
    175        def handle_failed_token_verification(e): 
    176            return self._token_verification_failed_callback(e.jwt_header, e.jwt_data) 
    177 
    178        @app.errorhandler(UserLookupError) 
    179        def handler_user_lookup_error(e): 
    180            return self._user_lookup_error_callback(e.jwt_header, e.jwt_data) 
    181 
    182        @app.errorhandler(WrongTokenError) 
    183        def handle_wrong_token_error(e): 
    184            return self._invalid_token_callback(str(e)) 
    185 
    186    @staticmethod 
    187    def _set_default_configuration_options(app: Flask) -> None: 
    188        app.config.setdefault( 
    189            "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15) 
    190        ) 
    191        app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie") 
    192        app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/") 
    193        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token") 
    194        app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/") 
    195        app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token") 
    196        app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN") 
    197        app.config.setdefault("JWT_ALGORITHM", "HS256") 
    198        app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True) 
    199        app.config.setdefault("JWT_COOKIE_DOMAIN", None) 
    200        app.config.setdefault("JWT_COOKIE_SAMESITE", None) 
    201        app.config.setdefault("JWT_COOKIE_SECURE", False) 
    202        app.config.setdefault("JWT_CSRF_CHECK_FORM", False) 
    203        app.config.setdefault("JWT_CSRF_IN_COOKIES", True) 
    204        app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"]) 
    205        app.config.setdefault("JWT_DECODE_ALGORITHMS", None) 
    206        app.config.setdefault("JWT_DECODE_AUDIENCE", None) 
    207        app.config.setdefault("JWT_DECODE_ISSUER", None) 
    208        app.config.setdefault("JWT_DECODE_LEEWAY", 0) 
    209        app.config.setdefault("JWT_ENCODE_AUDIENCE", None) 
    210        app.config.setdefault("JWT_ENCODE_ISSUER", None) 
    211        app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg") 
    212        app.config.setdefault("JWT_HEADER_NAME", "Authorization") 
    213        app.config.setdefault("JWT_HEADER_TYPE", "Bearer") 
    214        app.config.setdefault("JWT_IDENTITY_CLAIM", "sub") 
    215        app.config.setdefault("JWT_JSON_KEY", "access_token") 
    216        app.config.setdefault("JWT_PRIVATE_KEY", None) 
    217        app.config.setdefault("JWT_PUBLIC_KEY", None) 
    218        app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt") 
    219        app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "") 
    220        app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie") 
    221        app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/") 
    222        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token") 
    223        app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/") 
    224        app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token") 
    225        app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN") 
    226        app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token") 
    227        app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30)) 
    228        app.config.setdefault("JWT_SECRET_KEY", None) 
    229        app.config.setdefault("JWT_SESSION_COOKIE", True) 
    230        app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",)) 
    231        app.config.setdefault("JWT_ENCODE_NBF", True) 
    232 
    233    def additional_claims_loader(self, callback: Callable) -> Callable: 
    234        """ 
    235        This decorator sets the callback function used to add additional claims 
    236        when creating a JWT. The claims returned by this function will be merged 
    237        with any claims passed in via the ``additional_claims`` argument to 
    238        :func:`~flask_jwt_extended.create_access_token` or 
    239        :func:`~flask_jwt_extended.create_refresh_token`. 
    240 
    241        The decorated function must take **one** argument. 
    242 
    243        The argument is the identity that was used when creating a JWT. 
    244 
    245        The decorated function must return a dictionary of claims to add to the JWT. 
    246        """ 
    247        self._user_claims_callback = callback 
    248        return callback 
    249 
    250    def additional_headers_loader(self, callback: Callable) -> Callable: 
    251        """ 
    252        This decorator sets the callback function used to add additional headers 
    253        when creating a JWT. The headers returned by this function will be merged 
    254        with any headers passed in via the ``additional_headers`` argument to 
    255        :func:`~flask_jwt_extended.create_access_token` or 
    256        :func:`~flask_jwt_extended.create_refresh_token`. 
    257 
    258        The decorated function must take **one** argument. 
    259 
    260        The argument is the identity that was used when creating a JWT. 
    261 
    262        The decorated function must return a dictionary of headers to add to the JWT. 
    263        """ 
    264        self._jwt_additional_header_callback = callback 
    265        return callback 
    266 
    267    def decode_key_loader(self, callback: Callable) -> Callable: 
    268        """ 
    269        This decorator sets the callback function for dynamically setting the JWT 
    270        decode key based on the **UNVERIFIED** contents of the token. Think 
    271        carefully before using this functionality, in most cases you probably 
    272        don't need it. 
    273 
    274        The decorated function must take **two** arguments. 
    275 
    276        The first argument is a dictionary containing the header data of the 
    277        unverified JWT. 
    278 
    279        The second argument is a dictionary containing the payload data of the 
    280        unverified JWT. 
    281 
    282        The decorated function must return a *string* that is used to decode and 
    283        verify the token. 
    284        """ 
    285        self._decode_key_callback = callback 
    286        return callback 
    287 
    288    def encode_key_loader(self, callback: Callable) -> Callable: 
    289        """ 
    290        This decorator sets the callback function for dynamically setting the JWT 
    291        encode key based on the tokens identity. Think carefully before using this 
    292        functionality, in most cases you probably don't need it. 
    293 
    294        The decorated function must take **one** argument. 
    295 
    296        The argument is the identity used to create this JWT. 
    297 
    298        The decorated function must return a *string* which is the secrete key used to 
    299        encode the JWT. 
    300        """ 
    301        self._encode_key_callback = callback 
    302        return callback 
    303 
    304    def expired_token_loader(self, callback: Callable) -> Callable: 
    305        """ 
    306        This decorator sets the callback function for returning a custom 
    307        response when an expired JWT is encountered. 
    308 
    309        The decorated function must take **two** arguments. 
    310 
    311        The first argument is a dictionary containing the header data of the JWT. 
    312 
    313        The second argument is a dictionary containing the payload data of the JWT. 
    314 
    315        The decorated function must return a Flask Response. 
    316        """ 
    317        self._expired_token_callback = callback 
    318        return callback 
    319 
    320    def invalid_token_loader(self, callback: Callable) -> Callable: 
    321        """ 
    322        This decorator sets the callback function for returning a custom 
    323        response when an invalid JWT is encountered. 
    324 
    325        This decorator sets the callback function that will be used if an 
    326        invalid JWT attempts to access a protected endpoint. 
    327 
    328        The decorated function must take **one** argument. 
    329 
    330        The argument is a string which contains the reason why a token is invalid. 
    331 
    332        The decorated function must return a Flask Response. 
    333        """ 
    334        self._invalid_token_callback = callback 
    335        return callback 
    336 
    337    def needs_fresh_token_loader(self, callback: Callable) -> Callable: 
    338        """ 
    339        This decorator sets the callback function for returning a custom 
    340        response when a valid and non-fresh token is used on an endpoint 
    341        that is marked as ``fresh=True``. 
    342 
    343        The decorated function must take **two** arguments. 
    344 
    345        The first argument is a dictionary containing the header data of the JWT. 
    346 
    347        The second argument is a dictionary containing the payload data of the JWT. 
    348 
    349        The decorated function must return a Flask Response. 
    350        """ 
    351        self._needs_fresh_token_callback = callback 
    352        return callback 
    353 
    354    def revoked_token_loader(self, callback: Callable) -> Callable: 
    355        """ 
    356        This decorator sets the callback function for returning a custom 
    357        response when a revoked token is encountered. 
    358 
    359        The decorated function must take **two** arguments. 
    360 
    361        The first argument is a dictionary containing the header data of the JWT. 
    362 
    363        The second argument is a dictionary containing the payload data of the JWT. 
    364 
    365        The decorated function must return a Flask Response. 
    366        """ 
    367        self._revoked_token_callback = callback 
    368        return callback 
    369 
    370    def token_in_blocklist_loader(self, callback: Callable) -> Callable: 
    371        """ 
    372        This decorator sets the callback function used to check if a JWT has 
    373        been revoked. 
    374 
    375        The decorated function must take **two** arguments. 
    376 
    377        The first argument is a dictionary containing the header data of the JWT. 
    378 
    379        The second argument is a dictionary containing the payload data of the JWT. 
    380 
    381        The decorated function must be return ``True`` if the token has been 
    382        revoked, ``False`` otherwise. 
    383        """ 
    384        self._token_in_blocklist_callback = callback 
    385        return callback 
    386 
    387    def token_verification_failed_loader(self, callback: Callable) -> Callable: 
    388        """ 
    389        This decorator sets the callback function used to return a custom 
    390        response when the claims verification check fails. 
    391 
    392        The decorated function must take **two** arguments. 
    393 
    394        The first argument is a dictionary containing the header data of the JWT. 
    395 
    396        The second argument is a dictionary containing the payload data of the JWT. 
    397 
    398        The decorated function must return a Flask Response. 
    399        """ 
    400        self._token_verification_failed_callback = callback 
    401        return callback 
    402 
    403    def token_verification_loader(self, callback: Callable) -> Callable: 
    404        """ 
    405        This decorator sets the callback function used for custom verification 
    406        of a valid JWT. 
    407 
    408        The decorated function must take **two** arguments. 
    409 
    410        The first argument is a dictionary containing the header data of the JWT. 
    411 
    412        The second argument is a dictionary containing the payload data of the JWT. 
    413 
    414        The decorated function must return ``True`` if the token is valid, or 
    415        ``False`` otherwise. 
    416        """ 
    417        self._token_verification_callback = callback 
    418        return callback 
    419 
    420    def unauthorized_loader(self, callback: Callable) -> Callable: 
    421        """ 
    422        This decorator sets the callback function used to return a custom 
    423        response when no JWT is present. 
    424 
    425        The decorated function must take **one** argument. 
    426 
    427        The argument is a string that explains why the JWT could not be found. 
    428 
    429        The decorated function must return a Flask Response. 
    430        """ 
    431        self._unauthorized_callback = callback 
    432        return callback 
    433 
    434    def user_identity_loader(self, callback: Callable) -> Callable: 
    435        """ 
    436        This decorator sets the callback function used to convert an identity to 
    437        a JSON serializable format when creating JWTs. This is useful for 
    438        using objects (such as SQLAlchemy instances) as the identity when 
    439        creating your tokens. 
    440 
    441        The decorated function must take **one** argument. 
    442 
    443        The argument is the identity that was used when creating a JWT. 
    444 
    445        The decorated function must return JSON serializable data. 
    446        """ 
    447        self._user_identity_callback = callback 
    448        return callback 
    449 
    450    def user_lookup_loader(self, callback: Callable) -> Callable: 
    451        """ 
    452        This decorator sets the callback function used to convert a JWT into 
    453        a python object that can be used in a protected endpoint. This is useful 
    454        for automatically loading a SQLAlchemy instance based on the contents 
    455        of the JWT. 
    456 
    457        The object returned from this function can be accessed via 
    458        :attr:`~flask_jwt_extended.current_user` or 
    459        :meth:`~flask_jwt_extended.get_current_user` 
    460 
    461        The decorated function must take **two** arguments. 
    462 
    463        The first argument is a dictionary containing the header data of the JWT. 
    464 
    465        The second argument is a dictionary containing the payload data of the JWT. 
    466 
    467        The decorated function can return any python object, which can then be 
    468        accessed in a protected endpoint. If an object cannot be loaded, for 
    469        example if a user has been deleted from your database, ``None`` must be 
    470        returned to indicate that an error occurred loading the user. 
    471        """ 
    472        self._user_lookup_callback = callback 
    473        return callback 
    474 
    475    def user_lookup_error_loader(self, callback: Callable) -> Callable: 
    476        """ 
    477        This decorator sets the callback function used to return a custom 
    478        response when loading a user via 
    479        :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails. 
    480 
    481        The decorated function must take **two** arguments. 
    482 
    483        The first argument is a dictionary containing the header data of the JWT. 
    484 
    485        The second argument is a dictionary containing the payload data of the JWT. 
    486 
    487        The decorated function must return a Flask Response. 
    488        """ 
    489        self._user_lookup_error_callback = callback 
    490        return callback 
    491 
    492    def _encode_jwt_from_config( 
    493        self, 
    494        identity: Any, 
    495        token_type: str, 
    496        claims=None, 
    497        fresh: Fresh = False, 
    498        expires_delta: Optional[ExpiresDelta] = None, 
    499        headers=None, 
    500    ) -> str: 
    501        header_overrides = self._jwt_additional_header_callback(identity) 
    502        if headers is not None: 
    503            header_overrides.update(headers) 
    504 
    505        claim_overrides = self._user_claims_callback(identity) 
    506        if claims is not None: 
    507            claim_overrides.update(claims) 
    508 
    509        if expires_delta is None: 
    510            if token_type == "access": 
    511                expires_delta = config.access_expires 
    512            else: 
    513                expires_delta = config.refresh_expires 
    514 
    515        return _encode_jwt( 
    516            algorithm=config.algorithm, 
    517            audience=config.encode_audience, 
    518            claim_overrides=claim_overrides, 
    519            csrf=config.cookie_csrf_protect, 
    520            expires_delta=expires_delta, 
    521            fresh=fresh, 
    522            header_overrides=header_overrides, 
    523            identity=self._user_identity_callback(identity), 
    524            identity_claim_key=config.identity_claim_key, 
    525            issuer=config.encode_issuer, 
    526            json_encoder=config.json_encoder, 
    527            secret=self._encode_key_callback(identity), 
    528            token_type=token_type, 
    529            nbf=config.encode_nbf, 
    530        ) 
    531 
    532    def _decode_jwt_from_config( 
    533        self, encoded_token: str, csrf_value=None, allow_expired: bool = False 
    534    ) -> dict: 
    535        unverified_claims = jwt.decode( 
    536            encoded_token, 
    537            algorithms=config.decode_algorithms, 
    538            options={"verify_signature": False}, 
    539        ) 
    540        unverified_headers = jwt.get_unverified_header(encoded_token) 
    541        secret = self._decode_key_callback(unverified_headers, unverified_claims) 
    542 
    543        kwargs = { 
    544            "algorithms": config.decode_algorithms, 
    545            "audience": config.decode_audience, 
    546            "csrf_value": csrf_value, 
    547            "encoded_token": encoded_token, 
    548            "identity_claim_key": config.identity_claim_key, 
    549            "issuer": config.decode_issuer, 
    550            "leeway": config.leeway, 
    551            "secret": secret, 
    552            "verify_aud": config.decode_audience is not None, 
    553        } 
    554 
    555        try: 
    556            return _decode_jwt(**kwargs, allow_expired=allow_expired) 
    557        except ExpiredSignatureError as e: 
    558            # TODO: If we ever do another breaking change, don't raise this pyjwt 
    559            #       error directly, instead raise a custom error of ours from this 
    560            #       error. 
    561            e.jwt_header = unverified_headers  # type: ignore 
    562            e.jwt_data = _decode_jwt(**kwargs, allow_expired=True)  # type: ignore 
    563            raise