Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pyrate_limiter/bucket.py: 38%
127 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:51 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:51 +0000
1""" Implement this class to create
2a workable bucket for Limiter to use
3"""
4from abc import ABC
5from abc import abstractmethod
6from queue import Queue
7from threading import RLock
8from typing import List
9from typing import Tuple
11from .exceptions import InvalidParams
class AbstractBucket(ABC):
    """Base bucket interface"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        self._maxsize = maxsize

    def maxsize(self) -> int:
        """Return the maximum size of the bucket,
        ie the maximum number of item this bucket can hold
        """
        return self._maxsize

    @abstractmethod
    def size(self) -> int:
        """Return the current size of the bucket,
        ie the count of all items currently in the bucket
        """

    @abstractmethod
    def put(self, item: float) -> int:
        """Put an item (typically the current time) in the bucket
        Return 1 if successful, else 0
        """

    @abstractmethod
    def get(self, number: int) -> int:
        """Get items, remove them from the bucket in the FIFO order, and return the number of items
        that have been removed
        """

    @abstractmethod
    def all_items(self) -> List[float]:
        """Return a list as copies of all items in the bucket"""

    @abstractmethod
    def flush(self) -> None:
        """Flush/reset bucket"""

    def inspect_expired_items(self, time: float) -> Tuple[int, float]:
        """Find how many items in bucket that have slipped out of the time-window

        Returns:
            The number of unexpired items, and the time until the next item will expire
        """
        total = self.size()

        # Items are stored in FIFO order, so the first timestamp that is still
        # inside the window marks the boundary: everything from that index on
        # is unexpired.
        for position, timestamp in enumerate(self.all_items()):
            if timestamp > time:
                return total - position, round(timestamp - time, 3)

        # Every item has expired
        return 0, 0.0

    def lock_acquire(self):
        """Acquire a lock prior to beginning a new transaction, if needed"""

    def lock_release(self):
        """Release lock following a transaction, if needed"""
class MemoryQueueBucket(AbstractBucket):
    """A bucket that resides in memory using python's built-in Queue class"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        # Forward maxsize to the base class so that maxsize() reports the
        # actual capacity; previously super().__init__() was called without
        # arguments, leaving self._maxsize at 0 while the Queue was bounded.
        super().__init__(maxsize=maxsize)
        self._q: Queue = Queue(maxsize=maxsize)

    def size(self) -> int:
        """Count of items currently in the bucket"""
        return self._q.qsize()

    def put(self, item: float):
        """Put an item in the bucket and return 1 on success.

        Queue.put() itself returns None; returning 1 here honors the
        AbstractBucket contract ("Return 1 if successful").
        NOTE(review): on a full bounded queue, Queue.put() blocks rather
        than returning 0 — callers are expected to compare size() against
        maxsize() before putting.
        """
        self._q.put(item)
        return 1

    def get(self, number: int) -> int:
        """Remove up to `number` items in FIFO order; return how many were removed"""
        counter = 0
        for _ in range(number):
            self._q.get()
            counter += 1

        return counter

    def all_items(self) -> List[float]:
        """Return a snapshot copy of all items in the bucket"""
        return list(self._q.queue)

    def flush(self):
        """Drain the underlying queue completely"""
        while not self._q.empty():
            self._q.get()
class MemoryListBucket(AbstractBucket):
    """A bucket that resides in memory using python's List"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        super().__init__(maxsize=maxsize)
        self._q: List[float] = []
        # RLock: re-entrant, so methods that call other locked methods
        # (e.g. put() -> size()) do not deadlock.
        self._lock = RLock()

    def size(self) -> int:
        """Count of items currently in the bucket"""
        return len(self._q)

    def put(self, item: float):
        """Append an item if the bucket is not full; return 1 on success, else 0"""
        with self._lock:
            if self.size() < self.maxsize():
                self._q.append(item)
                return 1
            return 0

    def get(self, number: int) -> int:
        """Remove up to `number` items in FIFO order; return how many were removed"""
        with self._lock:
            counter = 0
            for _ in range(number):
                # NOTE: pop(0) is O(n); acceptable for typical bucket sizes
                self._q.pop(0)
                counter += 1

            return counter

    def all_items(self) -> List[float]:
        """Return a snapshot copy of all items in the bucket.

        Guarded by the lock for consistency with put()/get(), so the copy
        cannot be taken mid-mutation.
        """
        with self._lock:
            return self._q.copy()

    def flush(self):
        """Reset the bucket to empty (lock-guarded, consistent with other mutators)"""
        with self._lock:
            self._q = list()
class RedisBucket(AbstractBucket):
    """A bucket backed by a Redis instance

    Items are kept in a Redis list under a name derived from
    ``bucket_name`` and ``identity``.
    """

    def __init__(
        self,
        maxsize=0,
        redis_pool=None,
        bucket_name: str = None,
        identity: str = None,
        expire_time: int = None,
        **_kwargs,
    ):
        super().__init__(maxsize=maxsize)

        if not redis_pool:
            raise InvalidParams("Missing Redis connection pool")

        if not isinstance(bucket_name, str):
            # Fixed message: the keyword is `bucket_name` (was misspelled
            # as `bucket-name`/"distict" in the raised error text)
            msg = "keyword argument `bucket_name` is missing: a distinct name is required"
            raise InvalidParams(msg)

        self._pool = redis_pool
        self._bucket_name = f"{bucket_name}___{identity}"
        # Optional TTL (seconds) applied to the Redis key on every put
        self._expire_time = expire_time

    def get_connection(self):
        """Obtain a connection from redis pool"""
        from redis import Redis  # type: ignore

        return Redis(connection_pool=self._pool)

    def get_pipeline(self):
        """Using redis pipeline for batch operation"""
        conn = self.get_connection()
        pipeline = conn.pipeline()
        return pipeline

    def size(self) -> int:
        """Length of the backing Redis list"""
        conn = self.get_connection()
        return conn.llen(self._bucket_name)

    def put(self, item: float):
        """Push an item if the bucket is not full; return 1 on success, else 0.

        NOTE(review): the size check and the push are not atomic, so the
        bucket can overfill under concurrent writers — confirm whether
        callers serialize access.
        """
        conn = self.get_connection()
        current_size = conn.llen(self._bucket_name)

        if current_size < self.maxsize():
            pipeline = self.get_pipeline()
            pipeline.rpush(self._bucket_name, item)

            if self._expire_time is not None:
                pipeline.expire(self._bucket_name, self._expire_time)

            pipeline.execute()
            return 1

        return 0

    def get(self, number: int) -> int:
        """Pop up to `number` items in FIFO order; return how many pops were issued"""
        pipeline = self.get_pipeline()
        counter = 0

        for _ in range(number):
            pipeline.lpop(self._bucket_name)
            counter += 1

        pipeline.execute()
        return counter

    def all_items(self) -> List[float]:
        """Return all items, decoded to floats and sorted ascending"""
        conn = self.get_connection()
        items = conn.lrange(self._bucket_name, 0, -1)
        return sorted([float(i.decode("utf-8")) for i in items])

    def flush(self):
        """Delete the backing Redis key entirely"""
        conn = self.get_connection()
        conn.delete(self._bucket_name)
class RedisClusterBucket(RedisBucket):
    """A bucket backed by a Redis cluster"""

    def get_connection(self):
        """Obtain a connection from redis pool"""
        # Imported lazily so the cluster dependency is only required when
        # this backend is actually used.
        from rediscluster import RedisCluster  # pylint: disable=import-outside-toplevel

        client = RedisCluster(connection_pool=self._pool)
        return client