Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pyrate_limiter/bucket.py: 38%

127 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:51 +0000

1""" Implement this class to create 

2a workable bucket for Limiter to use 

3""" 

from abc import ABC
from abc import abstractmethod
from queue import Empty
from queue import Full
from queue import Queue
from threading import RLock
from typing import List
from typing import Optional
from typing import Tuple

from .exceptions import InvalidParams

12 

13 

class AbstractBucket(ABC):
    """Base bucket interface.

    Implement the abstract methods of this class to create a workable bucket
    for the Limiter to use.
    """

    def __init__(self, maxsize: int = 0, **_kwargs):
        """Store the bucket capacity.

        Args:
            maxsize: the maximum number of items this bucket can hold.
            **_kwargs: extra keyword arguments are accepted and ignored.
        """
        self._maxsize = maxsize

    def maxsize(self) -> int:
        """Return the maximum size of the bucket,
        ie the maximum number of item this bucket can hold
        """
        return self._maxsize

    @abstractmethod
    def size(self) -> int:
        """Return the current size of the bucket,
        ie the count of all items currently in the bucket
        """

    @abstractmethod
    def put(self, item: float) -> int:
        """Put an item (typically the current time) in the bucket
        Return 1 if successful, else 0
        """

    @abstractmethod
    def get(self, number: int) -> int:
        """Get items, remove them from the bucket in the FIFO order, and return the number of items
        that have been removed
        """

    @abstractmethod
    def all_items(self) -> List[float]:
        """Return a list as copies of all items in the bucket"""

    @abstractmethod
    def flush(self) -> None:
        """Flush/reset bucket"""

    def inspect_expired_items(self, time: float) -> Tuple[int, float]:
        """Find how many items in bucket that have slipped out of the time-window

        Args:
            time: lower bound of the time-window; an item counts as unexpired
                when its value is strictly greater than ``time``.

        Returns:
            The number of unexpired items, and the time until the next item will expire
            (rounded to 3 decimals). ``(0, 0.0)`` when no item is later than ``time``.
        """
        # Work on a single snapshot so that the count and the scanned items
        # always agree, even if the bucket changes while we are looking at it.
        items = self.all_items()

        for log_idx, log_item in enumerate(items):
            if log_item > time:
                # Everything from the first unexpired item onwards is counted;
                # this assumes all_items() is in ascending order.
                return len(items) - log_idx, round(log_item - time, 3)

        return 0, 0.0

    def lock_acquire(self):
        """Acquire a lock prior to beginning a new transaction, if needed

        The base implementation does nothing.
        """

    def lock_release(self):
        """Release lock following a transaction, if needed

        The base implementation does nothing.
        """

74 

75 

class MemoryQueueBucket(AbstractBucket):
    """A bucket that resides in memory using python's built-in Queue class"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        """Create the backing queue.

        Args:
            maxsize: capacity handed to ``queue.Queue``; per the standard
                library, a value of zero or less makes the queue unbounded.
            **_kwargs: extra keyword arguments are accepted and ignored.
        """
        # Forward maxsize so that maxsize() reports the configured capacity
        super().__init__(maxsize=maxsize)
        self._q: Queue = Queue(maxsize=maxsize)

    def size(self) -> int:
        """Return the number of items currently in the queue"""
        return self._q.qsize()

    def put(self, item: float) -> int:
        """Put an item in the queue without blocking

        Returns:
            1 if the item was stored, 0 if the queue was full.
        """
        try:
            self._q.put_nowait(item)
        except Full:
            return 0
        return 1

    def get(self, number: int) -> int:
        """Remove up to ``number`` items in FIFO order without blocking

        Returns:
            The number of items actually removed, which is lower than
            ``number`` when the queue runs empty first.
        """
        counter = 0
        for _ in range(number):
            try:
                self._q.get_nowait()
            except Empty:
                break
            counter += 1

        return counter

    def all_items(self) -> List[float]:
        """Return a list as copies of all items in the queue"""
        return list(self._q.queue)

    def flush(self) -> None:
        """Remove every item from the queue"""
        # Drain until the queue itself reports Empty; a separate empty() check
        # followed by a blocking get() could hang if another consumer wins.
        while True:
            try:
                self._q.get_nowait()
            except Empty:
                return

103 

104 

class MemoryListBucket(AbstractBucket):
    """A bucket that resides in memory using python's List"""

    def __init__(self, maxsize: int = 0, **_kwargs):
        """Create an empty list-backed bucket.

        Args:
            maxsize: the maximum number of items this bucket can hold; with
                the default of 0, ``put`` never stores anything.
            **_kwargs: extra keyword arguments are accepted and ignored.
        """
        super().__init__(maxsize=maxsize)
        self._q: List[float] = []
        self._lock = RLock()

    def size(self) -> int:
        """Return the number of items currently in the list"""
        return len(self._q)

    def put(self, item: float) -> int:
        """Append an item if the bucket is below its maximum size

        Returns:
            1 if the item was stored, 0 if the bucket was full.
        """
        with self._lock:
            if self.size() < self.maxsize():
                self._q.append(item)
                return 1
            return 0

    def get(self, number: int) -> int:
        """Remove up to ``number`` items from the front of the list

        Returns:
            The number of items actually removed, which is lower than
            ``number`` when the bucket holds fewer items.
        """
        with self._lock:
            # Clamp to what is available (and to zero for negative requests),
            # then drop the whole prefix in one slice deletion.
            removed = max(0, min(number, len(self._q)))
            del self._q[:removed]
            return removed

    def all_items(self) -> List[float]:
        """Return a copy of all items in the bucket"""
        return self._q.copy()

    def flush(self) -> None:
        """Reset the bucket to an empty list"""
        with self._lock:
            self._q = []

137 

138 

class RedisBucket(AbstractBucket):
    """A bucket backed by a Redis instance

    Items are kept in a Redis list stored under a key built from
    ``bucket_name`` and ``identity``.
    """

    def __init__(
        self,
        maxsize=0,
        redis_pool=None,
        bucket_name: Optional[str] = None,
        identity: Optional[str] = None,
        expire_time: Optional[int] = None,
        **_kwargs,
    ):
        """Validate the parameters and remember the Redis settings.

        Args:
            maxsize: the maximum number of items this bucket can hold.
            redis_pool: connection pool used to build Redis clients.
            bucket_name: name used as the first part of the Redis key.
            identity: value used as the second part of the Redis key.
            expire_time: when not None, passed to Redis EXPIRE on the key
                after every successful ``put``.
            **_kwargs: extra keyword arguments are accepted and ignored.

        Raises:
            InvalidParams: if ``redis_pool`` is missing or ``bucket_name`` is
                not a string.
        """
        super().__init__(maxsize=maxsize)

        if not redis_pool:
            raise InvalidParams("Missing Redis connection pool")

        if not isinstance(bucket_name, str):
            msg = "keyword argument `bucket_name` is missing: a distinct name is required"
            raise InvalidParams(msg)

        self._pool = redis_pool
        self._bucket_name = f"{bucket_name}___{identity}"
        self._expire_time = expire_time

    def get_connection(self):
        """Obtain a connection from redis pool"""
        # Imported lazily so the module can be imported without redis installed
        from redis import Redis  # type: ignore

        return Redis(connection_pool=self._pool)

    def get_pipeline(self):
        """Using redis pipeline for batch operation"""
        conn = self.get_connection()
        pipeline = conn.pipeline()
        return pipeline

    def size(self) -> int:
        """Return the length of the Redis list"""
        conn = self.get_connection()
        return conn.llen(self._bucket_name)

    def put(self, item: float) -> int:
        """Append an item to the Redis list if it is below the maximum size

        Returns:
            1 if the item was pushed, 0 if the list was already full.
        """
        conn = self.get_connection()
        current_size = conn.llen(self._bucket_name)

        # NOTE: the size check and the push are separate Redis calls
        if current_size < self.maxsize():
            pipeline = self.get_pipeline()
            pipeline.rpush(self._bucket_name, item)

            if self._expire_time is not None:
                pipeline.expire(self._bucket_name, self._expire_time)

            pipeline.execute()
            return 1

        return 0

    def get(self, number: int) -> int:
        """Pop up to ``number`` items from the head of the Redis list

        Returns:
            The number of items actually removed.
        """
        if number <= 0:
            return 0

        pipeline = self.get_pipeline()
        for _ in range(number):
            pipeline.lpop(self._bucket_name)

        # LPOP replies with None once the list is empty, so only count the
        # replies that carried an item.
        replies = pipeline.execute()
        return sum(1 for reply in replies if reply is not None)

    def all_items(self) -> List[float]:
        """Return all items of the Redis list as floats, in ascending order"""
        conn = self.get_connection()
        items = conn.lrange(self._bucket_name, 0, -1)
        # Replies are bytes by default but str when the client decodes
        # responses; accept both.
        return sorted(
            float(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
            for raw in items
        )

    def flush(self) -> None:
        """Delete the Redis key holding the bucket"""
        conn = self.get_connection()
        conn.delete(self._bucket_name)

215 

216 

class RedisClusterBucket(RedisBucket):
    """A bucket backed by a Redis cluster"""

    def get_connection(self):
        """Build a RedisCluster client from the stored connection pool"""
        # Lazy import, done at call time rather than at module import time
        import rediscluster  # pylint: disable=import-outside-toplevel

        cluster_client = rediscluster.RedisCluster(connection_pool=self._pool)
        return cluster_client