Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/auth/token.py: 46%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

83 statements  

1from abc import ABC, abstractmethod 

2from datetime import datetime, timezone 

3 

4from redis.auth.err import InvalidTokenSchemaErr 

5 

6 

class TokenInterface(ABC):
    """Abstract contract every authentication token type must fulfil.

    Concrete subclasses provide the raw token value, access to individual
    claims, and timing information (the ``*_ms`` accessors are expressed in
    milliseconds).
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Report whether the token is past its expiry time."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the token's remaining time-to-live."""

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up the claim named *key* on the token."""

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token value."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp in milliseconds."""

31 

32 

class TokenResponse:
    """Holder for a token plus a helper to compute its lifetime.

    The wrapped object is expected to implement ``TokenInterface``.
    """

    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token object as-is."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the token's expiry timestamp minus its receipt timestamp.

        Both values come from the wrapped token's ``*_ms`` accessors, so the
        result is in milliseconds.
        """
        wrapped = self._token
        expires_ms = wrapped.get_expires_at_ms()
        received_ms = wrapped.get_received_at_ms()
        return expires_ms - received_ms

42 

43 

class SimpleToken(TokenInterface):
    """Token assembled from explicit values rather than parsed from a string.

    All timestamps are in milliseconds. An ``expires_at_ms`` of -1 is a
    sentinel meaning "never expires".
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return milliseconds left until expiry, or -1 for a non-expiring token."""
        if self.expires_at != -1:
            now_ms = datetime.now(timezone.utc).timestamp() * 1000
            return self.expires_at - now_ms
        return -1

    def is_expired(self) -> bool:
        """Return True once the remaining TTL is zero or negative.

        A token with the -1 "never expires" sentinel is never expired.
        """
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under *key*, or None when it is absent."""
        claims = self.claims
        return claims.get(key)

    def get_value(self) -> str:
        """Return the raw token value supplied at construction."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp (ms) exactly as supplied, including -1."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp (ms) supplied at construction."""
        return self.received_at

76 

77 

class JWToken(TokenInterface):
    """Token parsed from an encoded JWT string.

    The payload is decoded with signature verification disabled, so claims
    are read as-is. Every claim listed in ``REQUIRED_FIELDS`` must be present.
    The ``exp`` claim is treated as seconds since the epoch (it is scaled by
    1000 to get milliseconds); an ``exp`` of -1 means the token never expires.

    Raises:
        ImportError: if the PyJWT library is not installed.
        InvalidTokenSchemaErr: if a required claim is missing from the payload.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie
        self._value = token
        # Record the receipt time once. Re-sampling the clock on every call
        # would make the "received at" time (and thus exp - received_at)
        # drift as the token ages.
        self._received_at_ms = datetime.now(timezone.utc).timestamp() * 1000
        # Signature verification is disabled: the token is only parsed for
        # its claims. The algorithm is taken from the token's own header.
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    def is_expired(self) -> bool:
        """Return True once the current time has reached ``exp``.

        Returns False for the -1 "never expires" sentinel.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        return exp * 1000 <= datetime.now(timezone.utc).timestamp() * 1000

    def ttl(self) -> float:
        """Return milliseconds left until ``exp``, or -1 for a non-expiring token."""
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        return exp * 1000 - datetime.now(timezone.utc).timestamp() * 1000

    def try_get(self, key: str) -> str:
        """Return the decoded claim *key*, or None when it is absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the original encoded JWT string."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return the ``exp`` claim converted to milliseconds.

        NOTE(review): for the -1 sentinel this yields -1000.0, whereas
        SimpleToken reports -1 unchanged; confirm callers handle both.
        """
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        """Return the time (ms since epoch, UTC) this token object was created."""
        return self._received_at_ms

    def _validate_token(self):
        """Raise InvalidTokenSchemaErr listing any missing required claims."""
        missing = self.REQUIRED_FIELDS - set(self._decoded)
        if missing:
            raise InvalidTokenSchemaErr(missing)