1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""
16Generator wrapper for retryable async streaming RPCs.
17"""
18from __future__ import annotations
19
20from typing import (
21 cast,
22 Any,
23 Callable,
24 Iterable,
25 AsyncIterator,
26 AsyncIterable,
27 Awaitable,
28 TypeVar,
29 AsyncGenerator,
30 TYPE_CHECKING,
31)
32
import asyncio
import functools
import inspect
import sys
import time
37
38from google.api_core.retry.retry_base import _BaseRetry
39from google.api_core.retry.retry_base import _retry_error_helper
40from google.api_core.retry import exponential_sleep_generator
41from google.api_core.retry import build_retry_error
42from google.api_core.retry import RetryFailureReason
43
44
45if TYPE_CHECKING:
46 if sys.version_info >= (3, 10):
47 from typing import ParamSpec
48 else:
49 from typing_extensions import ParamSpec
50
51 _P = ParamSpec("_P") # target function call parameters
52 _Y = TypeVar("_Y") # yielded values
53
54
async def retry_target_stream(
    target: Callable[_P, AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]],
    predicate: Callable[[Exception], bool],
    sleep_generator: Iterable[float],
    timeout: float | None = None,
    on_error: Callable[[Exception], None] | None = None,
    exception_factory: Callable[
        [list[Exception], RetryFailureReason, float | None],
        tuple[Exception, Exception | None],
    ] = build_retry_error,
    init_args: tuple = (),
    init_kwargs: dict | None = None,
    **kwargs,
) -> AsyncGenerator[_Y, None]:
    """Create a generator wrapper that retries the wrapped stream if it fails.

    This is the lowest-level retry helper. Generally, you'll use the
    higher-level retry helper :class:`AsyncRetry`.

    On every attempt ``target`` is called with the same ``init_args`` /
    ``init_kwargs``; if the result is awaitable it is awaited first, and the
    resulting async iterable is streamed to the caller. When the target's
    iterator is an async generator, values passed in with ``asend``,
    exceptions passed in with ``athrow`` and ``aclose`` calls are forwarded
    to it; whatever the target yields in response to a forwarded ``athrow``
    is returned to the caller.

    Args:
        target: The generator function to call and retry.
        predicate: A callable used to determine if an
            exception raised by the target should be considered retryable.
            It should return True to retry or False otherwise.
        sleep_generator: An infinite iterator that determines
            how long to sleep between retries.
        timeout: How long to keep retrying the target. A falsy value
            (``None`` or ``0``) disables the deadline.
            Note: timeout is only checked before initiating a retry, so the target may
            run past the timeout value as long as it is healthy.
        on_error: If given, the on_error callback will be called with each
            retryable exception raised by the target. Any error raised by this
            function will *not* be caught.
        exception_factory: A function that is called when the retryable reaches
            a terminal failure state, used to construct an exception to be raised.
            It takes a list of all exceptions encountered, a retry.RetryFailureReason
            enum indicating the failure cause, and the original timeout value
            as arguments. It should return a tuple of the exception to be raised,
            along with the cause exception if any. The default implementation will raise
            a RetryError on timeout, or the last exception encountered otherwise.
        init_args: Positional arguments to pass to the target function.
        init_kwargs: Keyword arguments to pass to the target function.
            ``None`` (the default) means no keyword arguments.
        **kwargs: Only ``deadline`` is read: a DEPRECATED alias for
            ``timeout`` that overrides it when supplied.

    Returns:
        AsyncGenerator: A retryable generator that wraps the target generator function.

    Raises:
        ValueError: If the sleep generator stops yielding values.
        Exception: a custom exception specified by the exception_factory if provided.
            If no exception_factory is provided:
                google.api_core.RetryError: If the timeout is exceeded while retrying.
                Exception: If the target raises an error that isn't retryable.
    """
    # Avoid a shared mutable default; callers that pass a dict are unaffected.
    if init_kwargs is None:
        init_kwargs = {}
    # ``deadline`` is the deprecated spelling of ``timeout``; it wins when given.
    timeout = kwargs.get("deadline", timeout)
    deadline = time.monotonic() + timeout if timeout else None
    # keep track of retryable exceptions we encounter to pass in to exception_factory
    error_list: list[Exception] = []
    sleep_iter = iter(sleep_generator)
    # Decided once, from the first iterator obtained: does the target support
    # generator features (asend, athrow, aclose)?
    target_is_generator: bool | None = None
    target_iterator: AsyncIterator[_Y] | None = None

    # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper
    # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535
    while True:
        # Cleared per attempt so the ``finally`` clause only ever closes the
        # iterator belonging to the current attempt.
        target_iterator = None
        try:
            # Note: in the future, we can add a ResumptionStrategy object
            # to generate new args between calls. For now, use the same args
            # for each attempt.
            target_output: AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]] = target(
                *init_args, **init_kwargs
            )
            # gapic functions return the generator behind an awaitable.
            # Check explicitly instead of catching TypeError from ``await``:
            # a TypeError raised *inside* the awaitable must reach the retry
            # logic rather than being mistaken for "not awaitable".
            if inspect.isawaitable(target_output):
                target_output = await target_output
            target_iterator = cast(AsyncIterable["_Y"], target_output).__aiter__()

            if target_is_generator is None:
                target_is_generator = bool(getattr(target_iterator, "asend", None))

            sent_in = None
            while True:
                ## Read from target_iterator
                # If the target is a generator, advance it with `asend`
                # otherwise, use `anext`
                if target_is_generator:
                    next_value = await target_iterator.asend(sent_in)  # type: ignore
                else:
                    next_value = await target_iterator.__anext__()
                ## Yield from Wrapper to caller.
                # This inner loop repeats only when an exception thrown in by
                # the caller was delegated to the target and the target
                # answered with a new value, which must be yielded as-is.
                while True:
                    try:
                        # exceptions from `athrow` and `aclose` are injected here
                        sent_in = yield next_value
                    except GeneratorExit:
                        # wrapper received `aclose` while waiting on yield
                        if target_is_generator:
                            # pass to inner target_iterator for handling
                            await cast(
                                AsyncGenerator["_Y", None], target_iterator
                            ).aclose()
                        else:
                            raise
                        return
                    except:  # noqa: E722
                        # bare except catches any exception passed to `athrow`
                        if not target_is_generator:
                            raise
                        # Delegate to the target. If it handles the exception
                        # and yields, that value is the result of the caller's
                        # `athrow`; if it re-raises, the outer handlers
                        # (retry / StopAsyncIteration) take over.
                        next_value = await cast(
                            AsyncGenerator["_Y", None], target_iterator
                        ).athrow(cast(BaseException, sys.exc_info()[1]))
                    else:
                        break
        except StopAsyncIteration:
            # if iterator exhausted, return
            return
        # handle exceptions raised by the target_iterator
        # pylint: disable=broad-except
        # This function explicitly must deal with broad exceptions.
        except Exception as exc:
            # defer to shared logic for handling errors; raises when terminal
            next_sleep = _retry_error_helper(
                exc,
                deadline,
                sleep_iter,
                error_list,
                predicate,
                on_error,
                exception_factory,
                timeout,
            )
            # if exception not raised, sleep before next attempt
            await asyncio.sleep(next_sleep)

        finally:
            if target_is_generator and target_iterator is not None:
                await cast(AsyncGenerator["_Y", None], target_iterator).aclose()
196
197
class AsyncStreamingRetry(_BaseRetry):
    """Exponential retry decorator for async streaming rpcs.

    Calling an instance on a function returns a coroutine function; awaiting
    that coroutine function yields an AsyncGenerator which wraps the target
    stream in retry logic. If any exception is raised by the target, the
    entire stream will be retried within the wrapper.

    Although the default behavior is to retry transient API errors, a
    different predicate can be provided to retry other exceptions.

    Important Note: when a stream encounters a retryable error, it will
    silently construct a fresh iterator instance in the background
    and continue yielding (likely duplicate) values as if no error occurred.
    This is the most general way to retry a stream, but it often is not the
    desired behavior. Example: iter([1, 2, 1/0]) -> [1, 2, 1, 2, ...]

    There are two ways to build more advanced retry logic for streams:

    1. Wrap the target
        Use a ``target`` that maintains state between retries, and creates a
        different generator on each retry call. The same arguments are passed
        to every attempt, so a mutable argument can carry state between
        attempts. For example, you can wrap a grpc call in a function that
        modifies the request based on what has already been returned:

        .. code-block:: python

            async def attempt_with_modified_request(target, request, seen_items):
                # remove seen items from request on each attempt
                new_request = modify_request(request, seen_items)
                new_generator = await target(new_request)
                async for item in new_generator:
                    yield item
                    seen_items.append(item)

            retry_wrapped = AsyncStreamingRetry(...)(attempt_with_modified_request)
            retryable_gen = await retry_wrapped(target, request, [])

    2. Wrap the retry generator
        Alternatively, you can wrap the retryable generator itself before
        passing it to the end-user to add a filter on the stream. For
        example, you can keep track of the items that were successfully yielded
        in previous retry attempts, and only yield new items when the
        new attempt surpasses the previous ones:

        .. code-block:: python

            async def retryable_with_filter(target):
                stream_idx = 0

                # reset stream_idx when the stream is retried
                def on_error(e):
                    nonlocal stream_idx
                    stream_idx = 0

                # build retryable
                retry_wrapped = AsyncStreamingRetry(...)(target, on_error=on_error)
                retryable_gen = await retry_wrapped()
                # keep track of what has been yielded out of filter
                seen_items = []
                async for item in retryable_gen:
                    if stream_idx >= len(seen_items):
                        yield item
                        seen_items.append(item)
                    elif item != seen_items[stream_idx]:
                        raise ValueError("Stream differs from last attempt")
                    stream_idx += 1

            filter_retry_wrapped = retryable_with_filter(target)

    Args:
        predicate (Callable[Exception]): A callable that should return ``True``
            if the given exception is retryable.
        initial (float): The minimum amount of time to delay in seconds. This
            must be greater than 0.
        maximum (float): The maximum amount of time to delay in seconds.
        multiplier (float): The multiplier applied to the delay.
        timeout (Optional[float]): How long to keep retrying in seconds.
            Note: timeout is only checked before initiating a retry, so the target may
            run past the timeout value as long as it is healthy.
        on_error (Optional[Callable[Exception]]): A function to call while processing
            a retryable exception. Any error raised by this function will
            *not* be caught.
        is_stream (bool): Indicates whether the input function
            should be treated as a stream function (i.e. an AsyncGenerator,
            or function or coroutine that returns an AsyncIterable).
            If True, the iterable will be wrapped with retry logic, and any
            failed outputs will restart the stream. If False, only the input
            function call itself will be retried. Defaults to False.
            To avoid duplicate values, retryable streams should typically be
            wrapped in additional filter logic before use.
        deadline (float): DEPRECATED use ``timeout`` instead. If set it will
            override ``timeout`` parameter.
    """

    def __call__(
        self,
        func: Callable[..., AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]],
        on_error: Callable[[Exception], Any] | None = None,
    ) -> Callable[_P, Awaitable[AsyncGenerator[_Y, None]]]:
        """Wrap a callable with retry behavior.

        Args:
            func (Callable): The callable or stream to add retry behavior to.
            on_error (Optional[Callable[Exception]]): If given, the
                on_error callback will be called with each retryable exception
                raised by the wrapped function. Any error raised by this
                function will *not* be caught. If on_error was specified in the
                constructor, this value will be ignored.

        Returns:
            Callable: A coroutine function taking the same arguments as
            ``func``. Awaiting it returns an AsyncGenerator that invokes
            ``func`` with those arguments on every attempt, with retry
            behavior.
        """
        # A callback configured on the instance takes precedence.
        if self._on_error is not None:
            on_error = self._on_error

        @functools.wraps(func)
        async def retry_wrapped_func(
            *args: _P.args, **kwargs: _P.kwargs
        ) -> AsyncGenerator[_Y, None]:
            """A wrapper that calls target function with retry."""
            # Built per call so every stream starts its own backoff sequence.
            sleep_generator = exponential_sleep_generator(
                self._initial, self._maximum, multiplier=self._multiplier
            )
            return retry_target_stream(
                func,
                self._predicate,
                sleep_generator,
                self._timeout,
                on_error,
                init_args=args,
                init_kwargs=kwargs,
            )

        return retry_wrapped_func