Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pyrate_limiter/limiter.py: 30%

74 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:51 +0000

1from time import monotonic 

2from typing import Any 

3from typing import Callable 

4from typing import Dict 

5from typing import Type 

6from typing import Union 

7 

8from .bucket import AbstractBucket 

9from .bucket import MemoryQueueBucket 

10from .exceptions import BucketFullException 

11from .exceptions import InvalidParams 

12from .limit_context_decorator import LimitContextDecorator 

13from .request_rate import RequestRate 

14 

15 

class Limiter:
    """Main rate-limiter class

    Args:
        rates: Request rate definitions
        bucket_class: Bucket backend to use; may be any subclass of :py:class:`.AbstractBucket`.
            See :py:mod:`pyrate_limiter.bucket` for available bucket classes.
        bucket_kwargs: Extra keyword arguments to pass to the bucket class constructor.
        time_function: Time function that returns the current time as a float, in seconds
    """

    def __init__(
        self,
        *rates: RequestRate,
        bucket_class: Type[AbstractBucket] = MemoryQueueBucket,
        bucket_kwargs: Union[Dict[str, Any], None] = None,
        time_function: Union[Callable[[], float], None] = None,
    ):
        self._validate_rate_list(rates)

        self._rates = rates
        self._bkclass = bucket_class
        self._bucket_args = bucket_kwargs or {}
        # Fail fast: constructing one throwaway bucket surfaces bad bucket_class/bucket_kwargs here
        self._validate_bucket()

        # One bucket per identity, created lazily by _init_buckets()
        self.bucket_group: Dict[str, AbstractBucket] = {}
        self.time_function = monotonic
        if time_function is not None:
            self.time_function = time_function
        # Call for time_function to make an anchor if required.
        self.time_function()

    def _validate_rate_list(self, rates):  # pylint: disable=no-self-use
        """Raise exception if rates are missing or incorrectly ordered.

        Each rate must have both a strictly larger ``limit`` and a strictly larger ``interval``
        than the rate before it.

        Raises:
            :py:exc:`InvalidParams`: If no rates are given, or two adjacent rates are out of order
        """
        if not rates:
            raise InvalidParams("Rate(s) must be provided")

        for prev_rate, rate in zip(rates, rates[1:]):
            if rate.limit <= prev_rate.limit or rate.interval <= prev_rate.interval:
                raise InvalidParams(f"{prev_rate} cannot come before {rate}")

    def _validate_bucket(self):
        """Try initialize a bucket to check if ok

        The bucket is built with the same arguments ``_init_buckets`` uses and then discarded;
        any exception raised by the bucket constructor propagates to the caller.
        """
        self._bkclass(maxsize=self._rates[-1].limit, identity="_", **self._bucket_args)

    def _init_buckets(self, identities) -> None:
        """Initialize a bucket for each identity, if needed, and call its ``lock_acquire()``.

        The bucket's maxsize equals the max limit of request-rates (the last rate, since
        rates are validated to be strictly increasing).
        """
        maxsize = self._rates[-1].limit
        # Sorted so that buckets are always locked in the same order
        # NOTE(review): presumably this is to avoid lock-ordering deadlocks -- confirm
        for item_id in sorted(identities):
            # Membership test rather than truthiness, so an existing bucket whose object happens
            # to be falsy (e.g. a backend defining __len__) is never silently replaced
            if item_id not in self.bucket_group:
                self.bucket_group[item_id] = self._bkclass(
                    maxsize=maxsize,
                    identity=item_id,
                    **self._bucket_args,
                )
            self.bucket_group[item_id].lock_acquire()

    def _release_buckets(self, identities) -> None:
        """Release locks after bucket transactions, if applicable"""
        for item_id in sorted(identities):
            self.bucket_group[item_id].lock_release()

    def try_acquire(self, *identities: str) -> None:
        """Attempt to acquire an item, or raise an error if a rate limit has been exceeded.

        Args:
            identities: One or more identities to acquire. Typically this is the name of a service
                or resource that is being rate-limited.

        Raises:
            :py:exc:`BucketFullException`: If the bucket is full and the item cannot be acquired
        """
        self._init_buckets(identities)

        # Everything after the locks are taken runs under try/finally so that the locks are
        # released on every exit path, not only on success or BucketFullException
        try:
            now = round(self.time_function(), 3)

            for rate in self._rates:
                for item_id in identities:
                    bucket = self.bucket_group[item_id]
                    volume = bucket.size()

                    # Fewer items stored than this rate allows: cannot be over the limit
                    if volume < rate.limit:
                        continue

                    # Determine rate's starting point, and check requests made during its time window
                    item_count, remaining_time = bucket.inspect_expired_items(now - rate.interval)
                    if item_count >= rate.limit:
                        raise BucketFullException(item_id, rate, remaining_time)

                    # Remove expired bucket items beyond the last (maximum) rate limit,
                    if rate is self._rates[-1]:
                        bucket.get(volume - item_count)

            # If no buckets are full, add another item to each bucket representing the next request
            for item_id in identities:
                self.bucket_group[item_id].put(now)
        finally:
            self._release_buckets(identities)

    def ratelimit(
        self,
        *identities: str,
        delay: bool = False,
        max_delay: Union[int, float, None] = None,
    ):
        """A decorator and contextmanager that applies rate-limiting, with async support.
        Depending on arguments, calls that exceed the rate limit will either raise an exception, or
        sleep until space is available in the bucket.

        Args:
            identities: One or more identities to acquire. Typically this is the name of a service
                or resource that is being rate-limited.
            delay: Delay until the next request instead of raising an exception
            max_delay: The maximum allowed delay time (in seconds); anything over this will raise
                an exception

        Raises:
            :py:exc:`BucketFullException`: If the rate limit is reached, and ``delay=False`` or the
                delay exceeds ``max_delay``
        """
        return LimitContextDecorator(self, *identities, delay=delay, max_delay=max_delay)

    def get_current_volume(self, identity) -> int:
        """Get current bucket volume for a specific identity

        Raises:
            KeyError: If no bucket has been created for ``identity`` yet
        """
        return self.bucket_group[identity].size()

    def flush_all(self) -> int:
        """Flush every bucket in the group and return the number of buckets flushed"""
        for bucket in self.bucket_group.values():
            bucket.flush()

        return len(self.bucket_group)

153 cnt += 1 

154 

155 return cnt