"""Bucket implementations for the Limiter.

Subclass AbstractBucket to create a workable bucket for the Limiter to use.
"""
from abc import ABC
from abc import abstractmethod
from queue import Empty
from queue import Full
from queue import Queue
from threading import RLock
from typing import List
from typing import Tuple

from .exceptions import InvalidParams
12
13
class AbstractBucket(ABC):
    """Base bucket interface.

    Concrete subclasses provide the storage; this class defines the API the
    Limiter relies on plus a shared helper, ``inspect_expired_items``.
    """

    def __init__(self, maxsize: int = 0, **_kwargs):
        # Unknown keyword arguments are accepted and ignored so that every
        # bucket class can be constructed with the same keyword set.
        self._maxsize = maxsize

    def maxsize(self) -> int:
        """Return the maximum size of the bucket,
        i.e. the maximum number of items this bucket can hold
        """
        return self._maxsize

    @abstractmethod
    def size(self) -> int:
        """Return the current size of the bucket,
        i.e. the count of all items currently in the bucket
        """

    @abstractmethod
    def put(self, item: float) -> int:
        """Put an item (typically the current time) in the bucket.
        Return 1 if successful, else 0
        """

    @abstractmethod
    def get(self, number: int) -> int:
        """Remove up to `number` items from the bucket in FIFO order and
        return the number of items that have actually been removed
        """

    @abstractmethod
    def all_items(self) -> List[float]:
        """Return a copy of all items in the bucket as a list"""

    @abstractmethod
    def flush(self) -> None:
        """Flush/reset bucket"""

    def inspect_expired_items(self, time: float) -> Tuple[int, float]:
        """Find how many items in the bucket have not slipped out of the time-window.

        Items strictly greater than ``time`` count as unexpired. This assumes
        ``all_items()`` returns items in ascending order, so every item after
        the first unexpired one is unexpired too.

        Args:
            time: Cut-off value; items <= time are treated as expired.

        Returns:
            The number of unexpired items, and the time until the next item
            will expire (``item - time`` rounded to 3 decimals).
            ``(0, 0.0)`` when no item is unexpired.
        """
        # Derive the count from one snapshot. A separate size() call can
        # disagree with the snapshot if the bucket changes in between (and
        # costs an extra round-trip for remote buckets).
        items = self.all_items()
        total = len(items)

        for index, item in enumerate(items):
            if item > time:
                return total - index, round(item - time, 3)

        return 0, 0.0

    def lock_acquire(self):
        """Acquire a lock prior to beginning a new transaction, if needed"""

    def lock_release(self):
        """Release lock following a transaction, if needed"""
74
75
class MemoryQueueBucket(AbstractBucket):
    """A bucket that resides in memory using python's built-in Queue class"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        # Forward maxsize so maxsize() reports the real capacity instead of
        # the base-class default of 0.
        super().__init__(maxsize=maxsize)
        # queue.Queue treats maxsize <= 0 as unbounded.
        self._q: Queue = Queue(maxsize=maxsize)

    def size(self) -> int:
        """Return the number of queued items (as reported by Queue.qsize)"""
        return self._q.qsize()

    def put(self, item: float) -> int:
        """Add an item without blocking.

        Returns 1 if the item was queued, 0 if the queue is full. A blocking
        ``Queue.put`` would hang forever on a full bucket.
        """
        try:
            self._q.put_nowait(item)
        except Full:
            return 0
        return 1

    def get(self, number: int) -> int:
        """Remove up to `number` items in FIFO order without blocking.

        Returns the number of items actually removed. This is smaller than
        `number` when the queue runs out.
        """
        removed = 0
        for _ in range(number):
            try:
                self._q.get_nowait()
            except Empty:
                break
            removed += 1
        return removed

    def all_items(self) -> List[float]:
        """Return a list copy of the queued items, oldest first"""
        # Hold the queue's internal mutex so another thread cannot mutate the
        # underlying deque while it is being copied.
        with self._q.mutex:
            return list(self._q.queue)

    def flush(self):
        """Remove every item from the queue"""
        # Drain with get_nowait. A check-then-get loop could block forever if
        # another thread empties the queue between the check and the get.
        while True:
            try:
                self._q.get_nowait()
            except Empty:
                return
103
104
class MemoryListBucket(AbstractBucket):
    """A bucket that resides in memory using python's List"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        super().__init__(maxsize=maxsize)
        self._q: List[float] = []
        # Re-entrant so that put() can call size() while holding the lock.
        self._lock = RLock()

    def size(self) -> int:
        """Return the number of items currently stored"""
        return len(self._q)

    def put(self, item: float) -> int:
        """Append item if the bucket is below maxsize; return 1 on success, else 0"""
        with self._lock:
            if self.size() < self.maxsize():
                self._q.append(item)
                return 1
            return 0

    def get(self, number: int) -> int:
        """Remove up to `number` items in FIFO order.

        Returns the number of items actually removed. It is 0 for non-positive
        `number`, and capped at the current size instead of raising IndexError.
        """
        with self._lock:
            count = max(0, min(number, len(self._q)))
            # One slice deletion instead of repeated O(n) pop(0) calls.
            del self._q[:count]
            return count

    def all_items(self) -> List[float]:
        """Return a copy of the stored items, in insertion order"""
        with self._lock:
            return self._q.copy()

    def flush(self):
        """Remove every item from the bucket"""
        with self._lock:
            self._q.clear()
137
138
class RedisBucket(AbstractBucket):
    """A bucket backed by a Redis instance, stored as a Redis list"""

    def __init__(
        self,
        maxsize=0,
        redis_pool=None,
        bucket_name: str = None,
        identity: str = None,
        expire_time: int = None,
        **_kwargs,
    ):
        super().__init__(maxsize=maxsize)

        if not redis_pool:
            raise InvalidParams("Missing Redis connection pool")

        if not isinstance(bucket_name, str):
            msg = "keyword argument `bucket_name` is missing: a distinct name is required"
            raise InvalidParams(msg)

        self._pool = redis_pool
        # Redis key of the list holding this bucket's items.
        self._bucket_name = f"{bucket_name}___{identity}"
        # Optional TTL passed to Redis EXPIRE after every successful put.
        self._expire_time = expire_time

    def get_connection(self):
        """Obtain a connection from redis pool"""
        from redis import Redis  # type: ignore

        return Redis(connection_pool=self._pool)

    def get_pipeline(self):
        """Using redis pipeline for batch operation"""
        conn = self.get_connection()
        pipeline = conn.pipeline()
        return pipeline

    def size(self) -> int:
        """Return the length of the Redis list (LLEN)"""
        conn = self.get_connection()
        return conn.llen(self._bucket_name)

    def put(self, item: float) -> int:
        """Append item to the Redis list if it is below maxsize.

        Returns 1 if the item was pushed, else 0.

        NOTE(review): the LLEN check and the RPUSH are separate round-trips, so
        concurrent writers may push the list past maxsize. Confirm whether
        callers serialise puts (e.g. via lock_acquire).
        """
        conn = self.get_connection()
        if conn.llen(self._bucket_name) >= self.maxsize():
            return 0

        pipeline = self.get_pipeline()
        pipeline.rpush(self._bucket_name, item)
        if self._expire_time is not None:
            pipeline.expire(self._bucket_name, self._expire_time)
        pipeline.execute()
        return 1

    def get(self, number: int) -> int:
        """Pop up to `number` items from the head of the list (FIFO).

        Returns the number of items actually removed.
        """
        pipeline = self.get_pipeline()
        for _ in range(number):
            pipeline.lpop(self._bucket_name)
        results = pipeline.execute()
        # LPOP replies nil (None) once the list is empty; count only real pops.
        return sum(1 for result in results if result is not None)

    def all_items(self) -> List[float]:
        """Return all items in the list as floats, sorted ascending"""
        conn = self.get_connection()
        raw_items = conn.lrange(self._bucket_name, 0, -1)
        # Replies are bytes by default but str when the pool uses
        # decode_responses=True; accept both.
        return sorted(
            float(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
            for raw in raw_items
        )

    def flush(self):
        """Delete the Redis key backing this bucket"""
        conn = self.get_connection()
        conn.delete(self._bucket_name)
215
216
class RedisClusterBucket(RedisBucket):
    """Variant of RedisBucket whose connections go to a Redis cluster"""

    def get_connection(self):
        """Create a RedisCluster client bound to this bucket's connection pool"""
        import rediscluster  # pylint: disable=import-outside-toplevel

        client = rediscluster.RedisCluster(connection_pool=self._pool)
        return client