1from abc import ABC, abstractmethod
2from datetime import datetime, timezone
3
4from redis.auth.err import InvalidTokenSchemaErr
5
6
class TokenInterface(ABC):
    """Abstract contract that every authentication token type must fulfil.

    Concrete subclasses have to override all six methods below; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Return ``True`` when the token has expired, ``False`` otherwise."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the token's remaining time to live.

        The implementations in this module report milliseconds and use
        ``-1`` for a token that never expires.
        """

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up the claim stored under ``key``.

        The implementations in this module delegate to ``dict.get``, so a
        missing key yields ``None`` rather than raising.
        """

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token string."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the timestamp, in milliseconds, at which the token was received."""
31
32
class TokenResponse:
    """Thin wrapper around a token that exposes its overall lifetime."""

    def __init__(self, token: TokenInterface):
        """Store ``token`` so it can be handed back or queried later."""
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token object unchanged."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the token's expiry timestamp minus its received timestamp.

        Both values come straight from the wrapped token; no special
        handling is applied to sentinel values such as ``-1``.
        """
        wrapped = self._token
        expires_at_ms = wrapped.get_expires_at_ms()
        received_at_ms = wrapped.get_received_at_ms()
        return expires_at_ms - received_at_ms
42
43
class SimpleToken(TokenInterface):
    """Token whose value, timestamps and claims are supplied by the caller.

    An ``expires_at_ms`` of ``-1`` marks a token that never expires.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        """Keep the given token string, timestamps (milliseconds) and claims."""
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or ``-1`` if it never expires.

        The result is zero or negative once the expiry time has passed.
        """
        if self.expires_at != -1:
            # Current UTC time converted from seconds to milliseconds.
            current_ms = datetime.now(timezone.utc).timestamp() * 1000
            return self.expires_at - current_ms

        return -1

    def is_expired(self) -> bool:
        """Return ``True`` once the expiry time is reached; never for ``-1``."""
        never_expires = self.expires_at == -1
        return (not never_expires) and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under ``key``, or ``None`` if it is absent."""
        claims = self.claims
        return claims.get(key)

    def get_value(self) -> str:
        """Return the raw token string."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp exactly as passed to the constructor."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the received timestamp exactly as passed to the constructor."""
        return self.received_at
76
77
class JWToken(TokenInterface):
    """Token backed by the claims of a JSON Web Token.

    The JWT is decoded once at construction time **without signature
    verification**, so the claims must not be treated as authenticated by
    this class alone. The ``exp`` claim is interpreted as seconds and
    converted to milliseconds; an ``exp`` of ``-1`` marks a token that
    never expires.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """Decode ``token`` and check that all required claims are present.

        Args:
            token: The encoded JWT string.

        Raises:
            ImportError: If the PyJWT library is not installed.
            InvalidTokenSchemaErr: If a claim listed in ``REQUIRED_FIELDS``
                is missing from the decoded payload.
        """
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie
        self._value = token
        # Capture the arrival time once so get_received_at_ms() stays stable
        # instead of drifting forward on every call.
        self._received_at_ms = self._now_ms()
        # NOTE: the signature is deliberately not verified here; only the
        # payload is read.
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    @staticmethod
    def _now_ms() -> float:
        """Return the current UTC time in milliseconds."""
        return datetime.now(timezone.utc).timestamp() * 1000

    def is_expired(self) -> bool:
        """Return ``True`` once the ``exp`` time is reached; never for ``-1``."""
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        return exp * 1000 <= self._now_ms()

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or ``-1`` if it never expires."""
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        return exp * 1000 - self._now_ms()

    def try_get(self, key: str) -> str:
        """Return the decoded claim stored under ``key``, or ``None`` if absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the encoded JWT string passed to the constructor."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return the ``exp`` claim in milliseconds, or ``-1.0`` if it never expires."""
        exp = self._decoded["exp"]
        if exp == -1:
            # Keep the "never expires" sentinel intact rather than scaling it
            # to -1000, so it agrees with ttl() and is_expired().
            return -1.0

        return float(exp * 1000)

    def get_received_at_ms(self) -> float:
        """Return the time, in milliseconds, at which this token was constructed."""
        return self._received_at_ms

    def _validate_token(self):
        """Raise ``InvalidTokenSchemaErr`` listing any missing required claims."""
        missing_fields = self.REQUIRED_FIELDS - set(self._decoded)

        if missing_fields:
            raise InvalidTokenSchemaErr(missing_fields)