1from __future__ import annotations
2
3import time
4from typing import TYPE_CHECKING
5
6from packaging.version import Version
7
8from limits.typing import Optional, RedisClient, ScriptP, Tuple, Type, Union
9
10from ..util import get_package_data
11from .base import MovingWindowSupport, Storage
12
13if TYPE_CHECKING:
14 import redis
15
16
class RedisInteractor:
    """
    Shared redis helpers used by the redis backed storage classes.

    The counter helpers (``_incr``, ``_get``, ``_clear``, ``_get_expiry`` and
    ``_check``) run their commands on the connection that is passed in. The
    moving window helpers instead call the ``lua_moving_window`` /
    ``lua_acquire_window`` script objects, which are only declared here and
    must be assigned by the concrete class before use.
    """

    RES_DIR = "resources/redis/lua_scripts"

    # Lua script sources, loaded once when the class body is executed.
    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    # Callable script handles; declared here, assigned by the concrete class.
    lua_moving_window: ScriptP[Tuple[int, int]]
    lua_acquire_window: ScriptP[bool]

    # Namespace prepended to every key by :meth:`prefixed_key`.
    PREFIX = "LIMITS"

    def prefixed_key(self, key: str) -> str:
        """
        :param key: the un-prefixed rate limit key
        :return: the key namespaced as ``<PREFIX>:<key>``
        """
        return f"{self.PREFIX}:{key}"

    def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: forwarded to the moving window script as its second
         argument (presumably the maximum number of entries to inspect;
         confirm against ``moving_window.lua``)
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries). When the
         script returns a falsy result, ``(current time, 0)`` is returned.
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        # The script receives the oldest timestamp that still counts as being
        # inside the window, followed by the limit.
        window = self.lua_moving_window([key], [int(timestamp - expiry), limit])

        return window or (int(timestamp), 0)

    def _incr(
        self,
        key: str,
        expiry: int,
        connection: RedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param connection: Redis connection
        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if set, the expiry is (re)applied on every
         increment instead of only on the first one
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        key = self.prefixed_key(key)
        value = connection.incrby(key, amount)

        # ``value == amount`` means the counter started from zero, i.e. this
        # call created it, so the expiry still has to be attached.
        # NOTE: the increment and the expire are issued as two separate
        # commands on the connection.
        if elastic_expiry or value == amount:
            connection.expire(key, expiry)

        return value

    def _get(self, key: str, connection: RedisClient) -> int:
        """
        :param connection: Redis connection
        :param key: the key to get the counter value for
        :return: the current counter value, ``0`` if the lookup returns
         nothing
        """

        key = self.prefixed_key(key)
        return int(connection.get(key) or 0)

    def _clear(self, key: str, connection: RedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        key = self.prefixed_key(key)
        connection.delete(key)

    def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: RedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection (currently unused; the call goes
         through ``self.lua_acquire_window``)
        :param amount: the number of entries to acquire
        :return: the script result coerced to ``bool``
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = self.lua_acquire_window([key], [timestamp, limit, expiry, amount])

        return bool(acquired)

    def _get_expiry(self, key: str, connection: RedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the current time plus the remaining ttl of the key, as a
         whole number of seconds
        """

        key = self.prefixed_key(key)
        # Negative ttl results are clamped to zero so that the returned
        # timestamp is never earlier than the current time.
        return int(max(connection.ttl(key), 0) + time.time())

    def _check(self, connection: RedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: the result of ``PING``, or ``False`` if the ping raised
        """
        try:
            return connection.ping()
        # Any regular failure means "unhealthy". ``Exception`` (rather than a
        # bare ``except``) lets KeyboardInterrupt / SystemExit propagate
        # instead of being reported as an unhealthy storage.
        except Exception:
            return False
130
131
class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`redis`.
    """

    STORAGE_SCHEME = ["redis", "rediss", "redis+unix"]
    """The storage scheme for redis"""

    DEPENDENCIES = {"redis": Version("3.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional[redis.connection.ConnectionPool] = None,
        wrap_exceptions: bool = False,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form ``redis://[:password]@host:port``,
         ``redis://[:password]@host:port/db``,
         ``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
         This uri is passed directly to :func:`redis.from_url` except for the
         case of ``redis+unix://`` where it is replaced with ``unix://``.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`redis.Redis`
        :raise ConfigurationError: when the :pypi:`redis` library is not available
        """
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
        self.dependency = self.dependencies["redis"].module

        # Rewrite only the scheme. A plain ``str.replace`` would also alter
        # any later occurrence of "redis+unix" (e.g. inside a password or a
        # socket path) and silently corrupt the uri.
        unix_scheme = "redis+unix"
        if uri.startswith(unix_scheme):
            uri = "unix" + uri[len(unix_scheme):]

        if connection_pool is not None:
            # An explicit pool takes precedence; the uri is not used to build
            # the client in this case.
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.from_url(uri, **options)
        self.initialize_storage(uri)

    @property
    def base_exceptions(
        self,
    ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]:  # pragma: no cover
        """The exception type of the redis dependency (``RedisError``)."""
        return self.dependency.RedisError  # type: ignore[no-any-return]

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the Lua scripts on the client and keep the resulting
        callables on the instance.

        :param _uri: unused by this implementation
        """
        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        # Looked up through ``self`` like the other scripts so that a subclass
        # overriding the script source is honoured consistently.
        self.lua_incr_expire = self.storage.register_script(self.SCRIPT_INCR_EXPIRE)

    def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if set, the increment goes through the plain
         ``INCRBY`` + ``EXPIRE`` helper, which re-applies the expiry on every
         call; otherwise the ``incr_expire`` Lua script is used
        :param amount: the number to increment by
        :return: the counter value after the increment
        """

        if elastic_expiry:
            return super()._incr(key, expiry, self.storage, elastic_expiry, amount)

        key = self.prefixed_key(key)
        return int(self.lua_incr_expire([key], [expiry, amount]))

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """

        return super()._get(key, self.storage)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """

        return super()._clear(key, self.storage)

    def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number to increment by
        """

        return super()._acquire_entry(key, limit, expiry, self.storage, amount)

    def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """

        return super()._get_expiry(key, self.storage)

    def check(self) -> bool:
        """
        check if storage is healthy
        """

        return super()._check(self.storage)

    def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with
        ``self.PREFIX`` in blocks of 5000.

        .. warning::
           This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.

        :return: the integer result of the clear keys script (presumably the
         number of keys removed; confirm against ``clear_keys.lua``)
        """

        # Glob pattern that matches every key in this storage's namespace.
        prefix = self.prefixed_key("*")
        return int(self.lua_clear_keys([prefix]))