Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/limits/storage/redis.py: 3%
73 statements
coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1from __future__ import annotations
3import time
4from typing import TYPE_CHECKING
6from packaging.version import Version
8from limits.typing import Optional, RedisClient, ScriptP, Tuple, Union
10from ..util import get_package_data
11from .base import MovingWindowSupport, Storage
13if TYPE_CHECKING:
14 import redis
class RedisInteractor:
    """
    Mixin holding the redis operations used by :class:`RedisStorage`.

    The concrete storage class is responsible for registering the lua
    script sources below with its client and assigning the resulting
    callables to :attr:`lua_moving_window` and :attr:`lua_acquire_window`.
    """

    RES_DIR = "resources/redis/lua_scripts"

    # Lua script sources, read from the package data directory when the
    # class body is executed (i.e. at import time).
    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    # Registered script callables; only declared here, assigned by the
    # concrete storage class.
    lua_moving_window: ScriptP[Tuple[int, int]]
    lua_acquire_window: ScriptP[bool]

    def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: the rate limit amount, passed through to the moving
         window script
        :param expiry: expiry of entry (length of the window in seconds)
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        window = self.lua_moving_window([key], [int(timestamp - expiry), limit])

        # Fall back to "window starts now, nothing acquired" when the script
        # returns a falsy reply.
        return window or (int(timestamp), 0)

    def _incr(
        self,
        key: str,
        expiry: int,
        connection: RedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param connection: Redis connection
        :param elastic_expiry: if ``True`` the expiry is reset on every
         increment instead of only when the counter is first created
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        value = connection.incrby(key, amount)

        # ``value == amount`` means the counter was absent (or zero) before
        # this increment, so it needs an expiry set for the first time.
        # NOTE: INCRBY and EXPIRE are two separate commands here, not atomic.
        if elastic_expiry or value == amount:
            connection.expire(key, expiry)

        return value

    def _get(self, key: str, connection: RedisClient) -> int:
        """
        :param key: the key to get the counter value for
        :param connection: Redis connection
        :return: the counter value, or ``0`` if the key is missing
        """
        return int(connection.get(key) or 0)

    def _clear(self, key: str, connection: RedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        connection.delete(key)

    def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: RedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection (unused here; the registered
         :attr:`lua_acquire_window` script is already bound to a client)
        :param amount: the number of entries to acquire
        :return: whether the entries were acquired
        """
        timestamp = time.time()
        acquired = self.lua_acquire_window([key], [timestamp, limit, expiry, amount])

        return bool(acquired)

    def _get_expiry(self, key: str, connection: RedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the expiry as a unix timestamp (seconds); a missing key or a
         key without a TTL is treated as expiring now
        """
        # TTL returns a negative sentinel for missing/persistent keys; clamp
        # it to 0 so the result is never in the past.
        return int(max(connection.ttl(key), 0) + time.time())

    def _check(self, connection: RedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: the result of ``PING``, or ``False`` if it raised
        """
        try:
            return connection.ping()
        except Exception:
            # Any client/connection error means the storage is unhealthy.
            # BaseException subclasses (KeyboardInterrupt, SystemExit) are
            # intentionally not swallowed.
            return False
class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`redis`.
    """

    STORAGE_SCHEME = ["redis", "rediss", "redis+unix"]
    """The storage scheme for redis"""

    DEPENDENCIES = {"redis": Version("3.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional[redis.connection.ConnectionPool] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form ``redis://[:password]@host:port``,
         ``redis://[:password]@host:port/db``,
         ``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
         This uri is passed directly to :func:`redis.from_url` except for the
         case of ``redis+unix://`` where the scheme is replaced with ``unix://``.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`redis.Redis`
        :raise ConfigurationError: when the :pypi:`redis` library is not available
        """
        super().__init__(uri, **options)
        redis = self.dependencies["redis"].module

        # ``redis+unix`` is this library's scheme name; it is rewritten to
        # ``unix`` before being handed to redis. Only the scheme prefix is
        # touched so the rest of the uri (password, socket path, ...) is
        # never altered.
        unix_scheme = "redis+unix"
        if uri.startswith(unix_scheme):
            uri = "unix" + uri[len(unix_scheme):]

        if not connection_pool:
            self.storage = redis.from_url(uri, **options)
        else:
            self.storage = redis.Redis(connection_pool=connection_pool, **options)
        self.initialize_storage(uri)

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the lua scripts with the redis client.

        :param _uri: the (normalized) storage uri; unused here
        """
        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        # Looked up through ``self`` like the other scripts so subclasses
        # can override the script source consistently.
        self.lua_incr_expire = self.storage.register_script(self.SCRIPT_INCR_EXPIRE)

    def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is reset on every
         increment; otherwise the ``incr_expire`` lua script is used
         (presumably incrementing and setting the expiry in one call)
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        if elastic_expiry:
            return super()._incr(key, expiry, self.storage, elastic_expiry, amount)
        else:
            return int(self.lua_incr_expire([key], [expiry, amount]))

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the counter value, or ``0`` if the key is missing
        """
        return super()._get(key, self.storage)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        return super()._clear(key, self.storage)

    def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: whether the entries were acquired
        """
        return super()._acquire_entry(key, limit, expiry, self.storage, amount)

    def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        :return: the expiry as a unix timestamp (seconds)
        """
        return super()._get_expiry(key, self.storage)

    def check(self) -> bool:
        """
        check if storage is healthy
        """
        return super()._check(self.storage)

    def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with 'LIMITER'
        in blocks of 5000.

        :return: the integer returned by the clear-keys script (presumably
         the number of deleted keys)

        .. warning::
           This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.
        """
        return int(self.lua_clear_keys(["LIMITER*"]))