import time
import urllib
import urllib.parse
from typing import TYPE_CHECKING, cast

from deprecated.sphinx import versionadded
from packaging.version import Version

from limits.aio.storage.base import MovingWindowSupport, Storage
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Dict, Optional, Tuple, Type, Union
from limits.util import get_package_data

if TYPE_CHECKING:
    import coredis
    import coredis.commands
16
17
class RedisInteractor:
    """
    Shared implementation of the redis operations used by the async redis
    storages.

    The ``lua_*`` attributes are only annotated here; they must be assigned
    (e.g. by a subclass registering the ``SCRIPT_*`` sources against its
    client) before the methods that use them are called.
    """

    RES_DIR = "resources/redis/lua_scripts"

    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    lua_moving_window: "coredis.commands.Script[bytes]"
    lua_acquire_window: "coredis.commands.Script[bytes]"
    lua_clear_keys: "coredis.commands.Script[bytes]"
    lua_incr_expire: "coredis.commands.Script[bytes]"

    PREFIX = "LIMITS"

    def prefixed_key(self, key: str) -> str:
        """
        :param key: the un-prefixed rate limit key
        :return: ``key`` namespaced as ``"<PREFIX>:<key>"``
        """
        return f"{self.PREFIX}:{key}"

    async def _incr(
        self,
        key: str,
        expiry: int,
        connection: AsyncRedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param connection: Redis connection
        :param elastic_expiry: if ``True`` the expiry is reset on every
         increment, otherwise it is only set when the key is created
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        key = self.prefixed_key(key)
        value = await connection.incrby(key, amount)

        # ``value == amount`` is taken to mean the key did not exist before
        # this increment, so it has no expiry yet.
        if elastic_expiry or value == amount:
            await connection.expire(key, expiry)

        return value

    async def _get(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param key: the key to get the counter value for
        :param connection: Redis connection
        :return: the stored counter, or ``0`` when the key is missing
        """
        key = self.prefixed_key(key)
        return int(await connection.get(key) or 0)

    async def _clear(self, key: str, connection: AsyncRedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        key = self.prefixed_key(key)
        await connection.delete([key])

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed (forwarded to the moving
         window lua script)
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries); when the
         script returns nothing, ``(current timestamp, 0)``
        """
        key = self.prefixed_key(key)
        timestamp = int(time.time())
        window = await self.lua_moving_window.execute(
            [key], [timestamp - expiry, limit]
        )
        if window:
            return tuple(window)  # type: ignore
        return timestamp, 0

    async def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: AsyncRedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection (unused: the lua script runs on
         the client it was registered with)
        :param amount: the number of entries to acquire
        :return: whether the entries were acquired
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_window.execute(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the expiry as a unix timestamp (seconds)
        """
        key = self.prefixed_key(key)
        # Negative TTLs are clamped to 0, so the result is never in the past.
        return int(max(await connection.ttl(key), 0) + time.time())

    async def _check(self, connection: AsyncRedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: ``True`` if ``ping`` succeeded, ``False`` if it raised
        """
        try:
            await connection.ping()
        except Exception:
            # Only ``Exception``: a bare ``except`` would also swallow
            # ``asyncio.CancelledError`` (a ``BaseException``) and break
            # task cancellation.
            return False

        return True
142
143
@versionadded(version="2.1")
class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis", "async+rediss", "async+redis+unix"]
    """
    The storage schemes for redis to be used in an async context
    """
    DEPENDENCIES = {"coredis": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional["coredis.ConnectionPool"] = None,
        wrap_exceptions: bool = False,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form:

         - ``async+redis://[:password]@host:port``
         - ``async+redis://[:password]@host:port/db``
         - ``async+rediss://[:password]@host:port``
         - ``async+redis+unix:///path/to/sock?db=0`` etc...

         This uri is passed directly to :meth:`coredis.Redis.from_url` with
         the initial ``async`` removed, except for the case of ``async+redis+unix``
         where it is replaced with ``unix``.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.Redis`
        :raise ConfigurationError: when the redis library is not available
        """
        uri = uri.replace("async+redis", "redis", 1)
        # Only rewrite an actual ``redis+unix`` scheme prefix; an unbounded
        # replace could also alter matching text elsewhere in the uri
        # (e.g. inside a password).
        if uri.startswith("redis+unix"):
            uri = uri.replace("redis+unix", "unix", 1)

        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

        self.dependency = self.dependencies["coredis"].module

        if connection_pool is not None:
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Redis.from_url(uri, **options)

        self.initialize_storage(uri)

    @property
    def base_exceptions(
        self,
    ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]:  # pragma: no cover
        return self.dependency.exceptions.RedisError  # type: ignore[no-any-return]

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the lua scripts used by this storage on ``self.storage``.

        :param _uri: unused
        """
        # The registered scripts' ``execute`` is awaited by the callers.
        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        self.lua_incr_expire = self.storage.register_script(self.SCRIPT_INCR_EXPIRE)

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is reset on every increment
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        if elastic_expiry:
            return await super()._incr(
                key, expiry, self.storage, elastic_expiry, amount
            )

        # Increment and expiry are handled by a single script invocation.
        key = self.prefixed_key(key)
        return cast(
            int, await self.lua_incr_expire.execute([key], [expiry, amount])
        )

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(key, self.storage)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        await super()._clear(key, self.storage)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        return await super()._acquire_entry(key, limit, expiry, self.storage, amount)

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(key, self.storage)

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
        """
        return await super()._check(self.storage)

    async def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with
        ``self.PREFIX`` in blocks of 5000.

        .. warning:: This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.
        """
        prefix = self.prefixed_key("*")
        return cast(int, await self.lua_clear_keys.execute([prefix]))
290
291
@versionadded(version="2.1")
class RedisClusterStorage(RedisStorage):
    """
    Rate limit storage with redis cluster as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+cluster"]
    """
    The storage schemes for redis cluster to be used in an async context
    """

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`coredis.RedisCluster`"

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+cluster://[:password]@host:port,host:port``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.RedisCluster`
         (taking precedence over :attr:`DEFAULT_OPTIONS` and credentials
         parsed from ``uri``)
        :raise ConfigurationError: when the coredis library is not available
        """
        parsed = urllib.parse.urlparse(uri)
        parsed_auth: Dict[str, Union[float, str, bool]] = {}

        if parsed.username:
            parsed_auth["username"] = parsed.username
        if parsed.password:
            parsed_auth["password"] = parsed.password

        # ``urlparse`` separates credentials at the *last* "@", so do the
        # same; without credentials ``rpartition`` yields the whole netloc.
        hosts = parsed.netloc.rpartition("@")[2]
        cluster_hosts = []

        for loc in hosts.split(","):
            host, port = loc.split(":")
            cluster_hosts.append({"host": host, "port": int(port)})

        # Skip RedisStorage.__init__, which would build a non-cluster client.
        super(RedisStorage, self).__init__(
            uri, wrap_exceptions=wrap_exceptions, **options
        )

        self.dependency = self.dependencies["coredis"].module

        self.storage: "coredis.RedisCluster[str]" = self.dependency.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_OPTIONS, **parsed_auth, **options},
        )
        self.initialize_storage(uri)

    async def reset(self) -> Optional[int]:
        """
        Redis Clusters are sharded and deleting across shards
        can't be done atomically. Because of this, this reset loops over all
        keys that are prefixed with ``self.PREFIX`` and calls delete on them,
        one at a time.

        .. warning:: This operation was not tested with extremely large data sets.
           On a large production based system, care should be taken with its
           usage as it could be slow on very large data sets

        :return: the total count reported by the individual ``delete`` calls
        """
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(prefix)
        count = 0
        for key in keys:
            count += await self.storage.delete([key])
        return count
369
370
@versionadded(version="2.1")
class RedisSentinelStorage(RedisStorage):
    """
    Rate limit storage with redis sentinel as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+sentinel"]
    """The storage scheme for redis accessed via a redis sentinel installation"""

    DEPENDENCIES = {"coredis.sentinel": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        service_name: Optional[str] = None,
        use_replicas: bool = True,
        sentinel_kwargs: Optional[Dict[str, Union[float, str, bool]]] = None,
        wrap_exceptions: bool = False,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+sentinel://host:port,host:port/service_name``
        :param service_name, optional: sentinel service name
         (if not provided in `uri`)
        :param use_replicas: Whether to use replicas for read only operations
        :param sentinel_kwargs, optional: kwargs to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel`
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.sentinel.Sentinel`
        :raise ConfigurationError: when the coredis library is not available,
         or when no service name is given in either ``uri`` or
         ``service_name``.
        """
        parsed = urllib.parse.urlparse(uri)
        sentinel_configuration = []
        connection_options = options.copy()
        sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
        parsed_auth: Dict[str, Union[float, str, bool]] = {}

        if parsed.username:
            parsed_auth["username"] = parsed.username

        if parsed.password:
            parsed_auth["password"] = parsed.password

        # ``urlparse`` separates credentials at the *last* "@", so do the
        # same; without credentials ``rpartition`` yields the whole netloc.
        hosts = parsed.netloc.rpartition("@")[2]

        for loc in hosts.split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))

        # A service name in the uri path wins; an empty path (or a bare "/")
        # falls back to the ``service_name`` argument.
        self.service_name: Optional[str] = (
            parsed.path.replace("/", "") or service_name
        )

        if not self.service_name:
            raise ConfigurationError("'service_name' not provided")

        # Skip RedisStorage.__init__ (it would build a plain client from the
        # uri) but forward the same arguments the sibling storages do.
        super(RedisStorage, self).__init__(
            uri, wrap_exceptions=wrap_exceptions, **options
        )

        self.dependency = self.dependencies["coredis.sentinel"].module

        self.sentinel = self.dependency.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**parsed_auth, **sentinel_options},
            **{**parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.primary_for(self.service_name)
        self.storage_replica = self.sentinel.replica_for(self.service_name)
        self.use_replicas = use_replicas
        # Scripts are registered on the primary (``self.storage``).
        self.initialize_storage(uri)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
        on the replica when ``use_replicas`` is set, otherwise on the primary.
        """
        return await super()._check(
            self.storage_replica if self.use_replicas else self.storage
        )