Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/limits/storage/redis.py: 3%

73 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1from __future__ import annotations 

2 

3import time 

4from typing import TYPE_CHECKING 

5 

6from packaging.version import Version 

7 

8from limits.typing import Optional, RedisClient, ScriptP, Tuple, Union 

9 

10from ..util import get_package_data 

11from .base import MovingWindowSupport, Storage 

12 

13if TYPE_CHECKING: 

14 import redis 

15 

16 

class RedisInteractor:
    """
    Redis operations shared by the redis storage implementations.

    The ``_``-prefixed helpers take the redis ``connection`` to operate on as
    an explicit argument. The moving window helpers instead call the
    :attr:`lua_moving_window` / :attr:`lua_acquire_window` script objects,
    which are only declared (not assigned) on this class and must be set by a
    subclass before use.
    """

    RES_DIR = "resources/redis/lua_scripts"

    # Lua script sources, read from the package data directory when the class
    # body executes (i.e. at import time).
    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    # Registered script callables; declared here, assigned by subclasses.
    lua_moving_window: ScriptP[Tuple[int, int]]
    lua_acquire_window: ScriptP[bool]

    def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed in the window
        :param expiry: expiry of entry, in seconds (the length of the window)
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        # The script receives the (integer) timestamp at which the window
        # currently begins, i.e. ``expiry`` seconds ago.
        window = self.lua_moving_window([key], [int(timestamp - expiry), limit])

        # A falsy script result is reported as an empty window starting now.
        return window or (int(timestamp), 0)

    def _incr(
        self,
        key: str,
        expiry: int,
        connection: RedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param connection: Redis connection
        :param elastic_expiry: if ``True`` the expiry is reset on every
            increment, not only when the counter starts from zero
        :param amount: the number to increment by
        :return: the value of the counter after the increment
        """
        value = connection.incrby(key, amount)

        # ``value == amount`` means the counter was absent/zero before this
        # call, so it needs its initial TTL.
        if elastic_expiry or value == amount:
            connection.expire(key, expiry)

        return value

    def _get(self, key: str, connection: RedisClient) -> int:
        """
        :param key: the key to get the counter value for
        :param connection: Redis connection
        :return: the counter value, or ``0`` when the key has no value
        """
        return int(connection.get(key) or 0)

    def _clear(self, key: str, connection: RedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        connection.delete(key)

    def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: RedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection (not used here; the registered
            :attr:`lua_acquire_window` script is called instead)
        :param amount: the number of entries to acquire
        :return: the truthiness of the acquire script's result
        """
        timestamp = time.time()
        acquired = self.lua_acquire_window([key], [timestamp, limit, expiry, amount])

        return bool(acquired)

    def _get_expiry(self, key: str, connection: RedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the expiry as a unix timestamp in seconds; the current time
            when the key has no positive TTL
        """
        # ``TTL`` is negative for missing keys / keys without an expiry,
        # hence the clamp to 0.
        return int(max(connection.ttl(key), 0) + time.time())

    def _check(self, connection: RedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: the result of ``PING``, or ``False`` if it raised
        """
        try:
            return connection.ping()
        except Exception:
            # Any client/connection error means "unhealthy". A bare ``except``
            # would also swallow KeyboardInterrupt and SystemExit.
            return False

119 

120 

class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`redis`.
    """

    STORAGE_SCHEME = ["redis", "rediss", "redis+unix"]
    """The storage scheme for redis"""

    DEPENDENCIES = {"redis": Version("3.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional[redis.connection.ConnectionPool] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form ``redis://[:password]@host:port``,
            ``redis://[:password]@host:port/db``,
            ``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock``
            etc. This uri is passed directly to :func:`redis.from_url` except
            for the case of ``redis+unix://`` where the scheme is replaced
            with ``unix://``.
        :param connection_pool: if provided, the redis client is initialized
            with the connection pool and any other params passed as
            :paramref:`options`
        :param options: all remaining keyword arguments are passed
            directly to the constructor of :class:`redis.Redis`
        :raise ConfigurationError: when the :pypi:`redis` library is not
            available
        """
        super().__init__(uri, **options)
        redis = self.dependencies["redis"].module

        # redis-py spells the unix socket scheme ``unix://``. Rewrite only the
        # scheme prefix: replacing every "redis+unix" substring would also
        # corrupt e.g. ``redis+unix:///tmp/redis+unix.sock``.
        unix_scheme = "redis+unix"
        if uri.startswith(unix_scheme):
            uri = "unix" + uri[len(unix_scheme):]

        if not connection_pool:
            self.storage = redis.from_url(uri, **options)
        else:
            self.storage = redis.Redis(connection_pool=connection_pool, **options)
        self.initialize_storage(uri)

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the Lua scripts used by this storage on :attr:`storage`.

        :param _uri: the storage uri (after scheme rewriting); unused by this
            implementation
        """
        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        # Looked up through ``self`` like the scripts above so that subclasses
        # overriding SCRIPT_INCR_EXPIRE are honoured.
        self.lua_incr_expire = self.storage.register_script(self.SCRIPT_INCR_EXPIRE)

    def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is reset on every
            increment
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        if elastic_expiry:
            return super()._incr(key, expiry, self.storage, elastic_expiry, amount)
        else:
            # Single script call; presumably increments and applies the expiry
            # server-side (see incr_expire.lua) - confirm against the script.
            return int(self.lua_incr_expire([key], [expiry, amount]))

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the counter value, or ``0`` when the key has no value
        """
        return super()._get(key, self.storage)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        return super()._clear(key, self.storage)

    def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number to increment by
        :return: whether the entries were acquired
        """
        return super()._acquire_entry(key, limit, expiry, self.storage, amount)

    def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        :return: the expiry as a unix timestamp in seconds
        """
        return super()._get_expiry(key, self.storage)

    def check(self) -> bool:
        """
        check if storage is healthy
        """
        return super()._check(self.storage)

    def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with 'LIMITER'
        in blocks of 5000.

        .. warning::
           This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.

        :return: the integer result of the clear-keys script (presumably the
            number of keys deleted)
        """
        return int(self.lua_clear_keys(["LIMITER*"]))