1import uuid 
    2from datetime import datetime 
    3from datetime import timedelta 
    4from datetime import timezone 
    5from hmac import compare_digest 
    6from json import JSONEncoder 
    7from typing import Any 
    8from typing import Iterable 
    9from typing import List 
    10from typing import Type 
    11from typing import Union 
    12 
    13import jwt 
    14 
    15from flask_jwt_extended.exceptions import CSRFError 
    16from flask_jwt_extended.exceptions import JWTDecodeError 
    17from flask_jwt_extended.typing import ExpiresDelta 
    18from flask_jwt_extended.typing import Fresh 
    19 
    20 
    21def _encode_jwt( 
    22    algorithm: str, 
    23    audience: Union[str, Iterable[str]], 
    24    claim_overrides: dict, 
    25    csrf: bool, 
    26    expires_delta: ExpiresDelta, 
    27    fresh: Fresh, 
    28    header_overrides: dict, 
    29    identity: Any, 
    30    identity_claim_key: str, 
    31    issuer: str, 
    32    json_encoder: Type[JSONEncoder], 
    33    secret: str, 
    34    token_type: str, 
    35    nbf: bool, 
    36) -> str: 
    37    now = datetime.now(timezone.utc) 
    38 
    39    if isinstance(fresh, timedelta): 
    40        fresh = datetime.timestamp(now + fresh) 
    41 
    42    token_data = { 
    43        "fresh": fresh, 
    44        "iat": now, 
    45        "jti": str(uuid.uuid4()), 
    46        "type": token_type, 
    47        identity_claim_key: identity, 
    48    } 
    49 
    50    if nbf: 
    51        token_data["nbf"] = now 
    52 
    53    if csrf: 
    54        token_data["csrf"] = str(uuid.uuid4()) 
    55 
    56    if audience: 
    57        token_data["aud"] = audience 
    58 
    59    if issuer: 
    60        token_data["iss"] = issuer 
    61 
    62    if expires_delta: 
    63        token_data["exp"] = now + expires_delta 
    64 
    65    if claim_overrides: 
    66        token_data.update(claim_overrides) 
    67 
    68    return jwt.encode( 
    69        token_data, 
    70        secret, 
    71        algorithm, 
    72        json_encoder=json_encoder,  # type: ignore 
    73        headers=header_overrides, 
    74    ) 
    75 
    76 
    77def _decode_jwt( 
    78    algorithms: List, 
    79    allow_expired: bool, 
    80    audience: Union[str, Iterable[str]], 
    81    csrf_value: str, 
    82    encoded_token: str, 
    83    identity_claim_key: str, 
    84    issuer: str, 
    85    leeway: int, 
    86    secret: str, 
    87    verify_aud: bool, 
    88) -> dict: 
    89    options = {"verify_aud": verify_aud} 
    90    if allow_expired: 
    91        options["verify_exp"] = False 
    92 
    93    # This call verifies the ext, iat, and nbf claims 
    94    # This optionally verifies the exp and aud claims if enabled 
    95    decoded_token = jwt.decode( 
    96        encoded_token, 
    97        secret, 
    98        algorithms=algorithms, 
    99        audience=audience, 
    100        issuer=issuer, 
    101        leeway=leeway, 
    102        options=options, 
    103    ) 
    104 
    105    # Make sure that any custom claims we expect in the token are present 
    106    if identity_claim_key not in decoded_token: 
    107        raise JWTDecodeError("Missing claim: {}".format(identity_claim_key)) 
    108 
    109    if "type" not in decoded_token: 
    110        decoded_token["type"] = "access" 
    111 
    112    if "fresh" not in decoded_token: 
    113        decoded_token["fresh"] = False 
    114 
    115    if "jti" not in decoded_token: 
    116        decoded_token["jti"] = None 
    117 
    118    if csrf_value: 
    119        if "csrf" not in decoded_token: 
    120            raise JWTDecodeError("Missing claim: csrf") 
    121        if not compare_digest(decoded_token["csrf"], csrf_value): 
    122            raise CSRFError("CSRF double submit tokens do not match") 
    123 
    124    return decoded_token