Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/flask_jwt_extended/tokens.py: 74%

53 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-09 07:17 +0000

1import uuid 

2from datetime import datetime 

3from datetime import timedelta 

4from datetime import timezone 

5from hmac import compare_digest 

6from json import JSONEncoder 

7from typing import Any 

8from typing import Iterable 

9from typing import List 

10from typing import Type 

11from typing import Union 

12 

13import jwt 

14 

15from flask_jwt_extended.exceptions import CSRFError 

16from flask_jwt_extended.exceptions import JWTDecodeError 

17from flask_jwt_extended.typing import ExpiresDelta 

18from flask_jwt_extended.typing import Fresh 

19 

20 

def _encode_jwt(
    algorithm: str,
    audience: Union[str, Iterable[str]],
    claim_overrides: dict,
    csrf: bool,
    expires_delta: ExpiresDelta,
    fresh: Fresh,
    header_overrides: dict,
    identity: Any,
    identity_claim_key: str,
    issuer: str,
    json_encoder: Type[JSONEncoder],
    secret: str,
    token_type: str,
    nbf: bool,
) -> str:
    """Build the claim set for a new token and sign it with ``jwt.encode``.

    Args:
        algorithm: Signing algorithm name, passed straight to ``jwt.encode``.
        audience: Value for the ``aud`` claim; the claim is omitted when falsy.
        claim_overrides: Extra claims merged in last, so they replace any
            default claim with the same key (including ``jti`` and ``type``).
        csrf: When true, add a random ``csrf`` claim (a UUID4 string).
        expires_delta: Lifetime added to the issue time to produce ``exp``.
            ``False`` (or ``None``) means no ``exp`` claim is written. Any
            other value, including a zero-length delta, sets ``exp``.
        fresh: Stored in the ``fresh`` claim. A ``timedelta`` is converted
            to the POSIX timestamp at which freshness ends; any other value
            is stored as-is.
        header_overrides: Extra header fields passed to ``jwt.encode``.
        identity: Value stored under ``identity_claim_key``.
        identity_claim_key: Name of the claim that holds ``identity``.
        issuer: Value for the ``iss`` claim; the claim is omitted when falsy.
        json_encoder: Encoder class handed to ``jwt.encode`` for the payload.
        secret: Signing key passed to ``jwt.encode``.
        token_type: Stored in the ``type`` claim.
        nbf: When true, add an ``nbf`` claim equal to the issue time.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # One timestamp is shared by iat/nbf/exp/fresh so they are consistent.
    now = datetime.now(timezone.utc)

    if isinstance(fresh, timedelta):
        fresh = datetime.timestamp(now + fresh)

    token_data = {
        "fresh": fresh,
        "iat": now,
        "jti": str(uuid.uuid4()),
        "type": token_type,
        identity_claim_key: identity,
    }

    if nbf:
        token_data["nbf"] = now

    if csrf:
        token_data["csrf"] = str(uuid.uuid4())

    if audience:
        token_data["aud"] = audience

    if issuer:
        token_data["iss"] = issuer

    # Compare against the "disabled" sentinels explicitly instead of testing
    # truthiness. A zero-length delta such as ``timedelta(0)`` is falsy, so a
    # plain ``if expires_delta:`` silently minted a token that never expires
    # rather than one that expires immediately.
    if expires_delta is not False and expires_delta is not None:
        token_data["exp"] = now + expires_delta

    if claim_overrides:
        token_data.update(claim_overrides)

    return jwt.encode(
        token_data,
        secret,
        algorithm,
        json_encoder=json_encoder,  # type: ignore
        headers=header_overrides,
    )

75 

76 

def _decode_jwt(
    algorithms: List,
    allow_expired: bool,
    audience: Union[str, Iterable[str]],
    csrf_value: str,
    encoded_token: str,
    identity_claim_key: str,
    issuer: str,
    leeway: int,
    secret: str,
    verify_aud: bool,
) -> dict:
    """Verify and decode a token, then fill in defaults for optional claims.

    Args:
        algorithms: Algorithms accepted by ``jwt.decode``.
        allow_expired: When true, ``exp`` verification is disabled.
        audience: Expected audience, passed to ``jwt.decode``.
        csrf_value: Double-submit CSRF value supplied by the request. When
            truthy, it must match the token's ``csrf`` claim.
        encoded_token: The encoded token to verify.
        identity_claim_key: Claim that must be present in the token.
        issuer: Expected issuer, passed to ``jwt.decode``.
        leeway: Clock-skew allowance, passed to ``jwt.decode``.
        secret: Verification key, passed to ``jwt.decode``.
        verify_aud: Controls ``jwt.decode``'s ``verify_aud`` option.

    Returns:
        The decoded claims dict. ``type``, ``fresh`` and ``jti`` default to
        ``"access"``, ``False`` and ``None`` when absent.

    Raises:
        JWTDecodeError: The identity claim is missing, or ``csrf_value`` was
            given but the token has no ``csrf`` claim.
        CSRFError: ``csrf_value`` does not match the token's ``csrf`` claim.
        Exceptions raised by ``jwt.decode`` propagate unchanged.
    """
    options = {"verify_aud": verify_aud}
    if allow_expired:
        options["verify_exp"] = False

    # jwt.decode verifies the signature and the exp (unless disabled above),
    # iat and nbf claims. It also checks aud when verify_aud is enabled.
    decoded_token = jwt.decode(
        encoded_token,
        secret,
        algorithms=algorithms,
        audience=audience,
        issuer=issuer,
        leeway=leeway,
        options=options,
    )

    # Make sure that any custom claims we expect in the token are present
    if identity_claim_key not in decoded_token:
        raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))

    # Tokens created without these claims still get a predictable shape.
    decoded_token.setdefault("type", "access")
    decoded_token.setdefault("fresh", False)
    decoded_token.setdefault("jti", None)

    if csrf_value:
        if "csrf" not in decoded_token:
            raise JWTDecodeError("Missing claim: csrf")
        # hmac.compare_digest raises TypeError for str arguments containing
        # non-ASCII characters. csrf_value is request-controlled, so compare
        # UTF-8 bytes instead. A hostile non-ASCII value then becomes an
        # ordinary mismatch rather than an unhandled error. A non-string
        # claim cannot match a string value and is rejected the same way.
        expected_csrf = decoded_token["csrf"]
        if not isinstance(expected_csrf, str) or not compare_digest(
            expected_csrf.encode("utf-8"), csrf_value.encode("utf-8")
        ):
            raise CSRFError("CSRF double submit tokens do not match")

    return decoded_token