import inspect
from asyncio import sleep
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Iterable,
    Tuple,
    Type,
    TypeVar,
)

from redis.exceptions import ConnectionError, RedisError, TimeoutError
5
6if TYPE_CHECKING:
7 from redis.backoff import AbstractBackoff
8
9
# Generic result type: `Retry.call_with_retry` returns whatever the awaited
# `do()` operation produces.
T = TypeVar("T")
11
12
class Retry:
    """Retry an async operation a specific number of times after a failure.

    Only exceptions whose type is listed in ``supported_errors`` trigger a
    retry; the delay between attempts comes from the configured backoff
    object. Any other exception propagates immediately.
    """

    __slots__ = "_backoff", "_retries", "_supported_errors"

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[RedisError], ...] = (
            ConnectionError,
            TimeoutError,
        ),
    ):
        """
        Initialize a `Retry` object with a `Backoff` object
        that retries a maximum of `retries` times.
        `retries` can be negative to retry forever.
        You can specify the types of supported errors which trigger
        a retry with the `supported_errors` parameter.
        """
        self._backoff = backoff
        self._retries = retries
        # `except` requires a tuple of classes; coerce here so a list passed
        # by a caller fails loudly now rather than at the first error.
        self._supported_errors = tuple(supported_errors)

    def update_supported_errors(
        self, specified_errors: Iterable[Type[Exception]]
    ) -> None:
        """
        Add `specified_errors` to the error types that trigger a retry.

        Duplicates are dropped while preserving order: existing types keep
        their position and new ones are appended in the order given. (A set
        would also dedupe, but types hash by identity, so its iteration
        order - and thus the resulting tuple - could differ between runs.)
        """
        self._supported_errors = tuple(
            dict.fromkeys((*self._supported_errors, *specified_errors))
        )

    def get_retries(self) -> int:
        """
        Get the number of retries (negative means retry forever).
        """
        return self._retries

    def update_retries(self, value: int) -> None:
        """
        Set the number of retries (negative means retry forever).
        """
        self._retries = value

    async def call_with_retry(
        self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
    ) -> T:
        """
        Run `do` and return its result, retrying on supported errors.

        `do`: zero-argument callable returning an awaitable; it is called
            once per attempt, so at most `retries + 1` times when `retries`
            is non-negative.
        `fail`: failure handler, called with each supported error before the
            retry decision is made. It may be a coroutine function or a plain
            callable; its return value is awaited only if it is awaitable.

        Raises the last supported error once the retry budget is exhausted
        (never, when `retries` is negative). Exceptions not listed in
        `supported_errors` propagate immediately without calling `fail`.
        """
        self._backoff.reset()
        failures = 0
        while True:
            try:
                return await do()
            except self._supported_errors as error:
                failures += 1
                outcome = fail(error)
                # Support synchronous handlers too: awaiting a non-awaitable
                # (e.g. None) would raise TypeError and mask `error`.
                if inspect.isawaitable(outcome):
                    await outcome
                if self._retries >= 0 and failures > self._retries:
                    raise error
                delay = self._backoff.compute(failures)
                if delay > 0:
                    await sleep(delay)