1import time
2from typing import Optional
3
4from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
5
6
class JWKSetCache:
    """In-memory cache holding at most one JWK set, with an optional lifespan.

    The cached set is wrapped in a ``PyJWTSetWithTimestamp`` so that its age
    can be compared against ``lifespan`` when it is read back.
    """

    def __init__(self, lifespan: int) -> None:
        """Create an empty cache.

        :param lifespan: Seconds a cached set remains valid, measured against
            ``time.monotonic()``. Any negative value (e.g. ``-1``) disables
            expiry entirely.
        """
        # The cached set plus its insertion timestamp; ``None`` means empty.
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
        self.lifespan = lifespan

    def put(self, jwk_set: Optional[PyJWKSet]) -> None:
        """Store ``jwk_set`` in the cache, replacing any previous entry.

        :param jwk_set: The set to cache. Passing ``None`` clears the cache.
        """
        if jwk_set is not None:
            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
        else:
            # clear cache
            self.jwk_set_with_timestamp = None

    def get(self) -> Optional[PyJWKSet]:
        """Return the cached JWK set.

        :returns: The cached set, or ``None`` if the cache is empty or the
            entry has expired. An expired entry is not removed here; it is
            simply not returned.
        """
        if self.jwk_set_with_timestamp is None or self.is_expired():
            return None

        return self.jwk_set_with_timestamp.get_jwk_set()

    def is_expired(self) -> bool:
        """Report whether the cached entry has outlived ``lifespan``.

        :returns: ``True`` only when a set is cached, ``lifespan`` is
            non-negative, and more than ``lifespan`` seconds have passed since
            the entry's timestamp. An empty cache is never reported as expired.
        """
        # NOTE(review): this compares against time.monotonic(), so it assumes
        # PyJWTSetWithTimestamp records its timestamp with the same clock —
        # confirm in api_jwk; mixing clocks would make expiry meaningless.
        return (
            self.jwk_set_with_timestamp is not None
            and self.lifespan > -1
            and time.monotonic()
            > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
        )