from asyncio import sleep
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar

from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
6
if TYPE_CHECKING:
    # Only evaluated by static type checkers (TYPE_CHECKING is False at
    # runtime); the name is used solely in a string annotation below.
    from redis.backoff import AbstractBackoff

# Result type produced by the awaitable handed to ``Retry.call_with_retry``.
T = TypeVar("T")
11
12
class Retry(AbstractRetry[RedisError]):
    """
    Retry policy for asyncio Redis operations.

    Re-runs an awaitable operation when it raises one of the configured
    ``supported_errors``, sleeping between attempts for the delay computed by
    the ``backoff`` strategy.
    """

    # Defining ``__eq__`` in a class implicitly sets ``__hash__`` to None;
    # restore the base class's hash so instances remain hashable.
    __hash__ = AbstractRetry.__hash__

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[RedisError], ...] = (
            ConnectionError,
            TimeoutError,
        ),
    ):
        """
        `backoff`: strategy whose ``compute(failures)`` result is the delay,
            in seconds, slept before the next attempt in `call_with_retry`.
        `retries`: number of retries allowed after the first attempt in
            `call_with_retry`; a negative value retries indefinitely.
        `supported_errors`: exception types that trigger a retry; any other
            exception raised by the operation propagates immediately.
        """
        super().__init__(backoff, retries, supported_errors)

    def __eq__(self, other: Any) -> bool:
        """
        Two policies are equal when they have equal backoff objects, the same
        retry count, and the same supported error types (order ignored).
        """
        if not isinstance(other, Retry):
            return NotImplemented

        return (
            self._backoff == other._backoff
            and self._retries == other._retries
            and set(self._supported_errors) == set(other._supported_errors)
        )

    async def call_with_retry(
        self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
    ) -> T:
        """
        Execute an operation that might fail and return its result, retrying
        on supported errors until the retry budget is exhausted.

        `do`: the operation to call. Expects no argument and must return an
            awaitable.
        `fail`: the failure handler, called with each supported error raised
            by `do` before deciding whether to retry. It may be a plain
            callable or one that returns an awaitable (e.g. an ``async def``
            function or a lambda returning a coroutine); an awaitable result
            is awaited.

        Returns the result of the first successful `do()`. Re-raises the last
        supported error once the number of failures exceeds ``retries`` (never,
        if ``retries`` is negative). Errors that are not in
        ``supported_errors`` propagate immediately without calling `fail`.
        """
        self._backoff.reset()
        failures = 0
        while True:
            try:
                return await do()
            except self._supported_errors as error:
                failures += 1
                outcome = fail(error)
                # ``fail`` is typed as returning ``Any``: only await it when it
                # actually produced an awaitable, so synchronous handlers no
                # longer crash with "can't be used in 'await' expression".
                if isawaitable(outcome):
                    await outcome
                if self._retries >= 0 and failures > self._retries:
                    raise error
                # Delay before the next attempt; a non-positive value means
                # retry immediately.
                backoff = self._backoff.compute(failures)
                if backoff > 0:
                    await sleep(backoff)