1# Copyright 2023 Google LLC 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15""" 
    16Generator wrapper for retryable async streaming RPCs. 
    17""" 
    18from __future__ import annotations 
    19 
    20from typing import ( 
    21    cast, 
    22    Any, 
    23    Callable, 
    24    Iterable, 
    25    AsyncIterator, 
    26    AsyncIterable, 
    27    Awaitable, 
    28    TypeVar, 
    29    AsyncGenerator, 
    30    TYPE_CHECKING, 
    31) 
    32 
    33import asyncio 
    34import time 
    35import sys 
    36import functools 
    37 
    38from google.api_core.retry.retry_base import _BaseRetry 
    39from google.api_core.retry.retry_base import _retry_error_helper 
    40from google.api_core.retry import exponential_sleep_generator 
    41from google.api_core.retry import build_retry_error 
    42from google.api_core.retry import RetryFailureReason 
    43 
    44 
    45if TYPE_CHECKING: 
    46    if sys.version_info >= (3, 10): 
    47        from typing import ParamSpec 
    48    else: 
    49        from typing_extensions import ParamSpec 
    50 
    51    _P = ParamSpec("_P")  # target function call parameters 
    52    _Y = TypeVar("_Y")  # yielded values 
    53 
    54 
    55async def retry_target_stream( 
    56    target: Callable[_P, AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]], 
    57    predicate: Callable[[Exception], bool], 
    58    sleep_generator: Iterable[float], 
    59    timeout: float | None = None, 
    60    on_error: Callable[[Exception], None] | None = None, 
    61    exception_factory: Callable[ 
    62        [list[Exception], RetryFailureReason, float | None], 
    63        tuple[Exception, Exception | None], 
    64    ] = build_retry_error, 
    65    init_args: tuple = (), 
    66    init_kwargs: dict = {}, 
    67    **kwargs, 
    68) -> AsyncGenerator[_Y, None]: 
    69    """Create a generator wrapper that retries the wrapped stream if it fails. 
    70 
    71    This is the lowest-level retry helper. Generally, you'll use the 
    72    higher-level retry helper :class:`AsyncRetry`. 
    73 
    74    Args: 
    75        target: The generator function to call and retry. 
    76        predicate: A callable used to determine if an 
    77            exception raised by the target should be considered retryable. 
    78            It should return True to retry or False otherwise. 
    79        sleep_generator: An infinite iterator that determines 
    80            how long to sleep between retries. 
    81        timeout: How long to keep retrying the target. 
    82            Note: timeout is only checked before initiating a retry, so the target may 
    83            run past the timeout value as long as it is healthy. 
    84        on_error: If given, the on_error callback will be called with each 
    85            retryable exception raised by the target. Any error raised by this 
    86            function will *not* be caught. 
    87        exception_factory: A function that is called when the retryable reaches 
    88            a terminal failure state, used to construct an exception to be raised. 
    89            It takes a list of all exceptions encountered, a retry.RetryFailureReason 
    90            enum indicating the failure cause, and the original timeout value 
    91            as arguments. It should return a tuple of the exception to be raised, 
    92            along with the cause exception if any. The default implementation will raise 
    93            a RetryError on timeout, or the last exception encountered otherwise. 
    94        init_args: Positional arguments to pass to the target function. 
    95        init_kwargs: Keyword arguments to pass to the target function. 
    96 
    97    Returns: 
    98        AsyncGenerator: A retryable generator that wraps the target generator function. 
    99 
    100    Raises: 
    101        ValueError: If the sleep generator stops yielding values. 
    102        Exception: a custom exception specified by the exception_factory if provided. 
    103            If no exception_factory is provided: 
    104                google.api_core.RetryError: If the timeout is exceeded while retrying. 
    105                Exception: If the target raises an error that isn't retryable. 
    106    """ 
    107    target_iterator: AsyncIterator[_Y] | None = None 
    108    timeout = kwargs.get("deadline", timeout) 
    109    deadline = time.monotonic() + timeout if timeout else None 
    110    # keep track of retryable exceptions we encounter to pass in to exception_factory 
    111    error_list: list[Exception] = [] 
    112    sleep_iter = iter(sleep_generator) 
    113    target_is_generator: bool | None = None 
    114 
    115    # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper 
    116    # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535 
    117    while True: 
    118        # Start a new retry loop 
    119        try: 
    120            # Note: in the future, we can add a ResumptionStrategy object 
    121            # to generate new args between calls. For now, use the same args 
    122            # for each attempt. 
    123            target_output: AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]] = target( 
    124                *init_args, **init_kwargs 
    125            ) 
    126            try: 
    127                # gapic functions return the generator behind an awaitable 
    128                # unwrap the awaitable so we can work with the generator directly 
    129                target_output = await target_output  # type: ignore 
    130            except TypeError: 
    131                # was not awaitable, continue 
    132                pass 
    133            target_iterator = cast(AsyncIterable["_Y"], target_output).__aiter__() 
    134 
    135            if target_is_generator is None: 
    136                # Check if target supports generator features (asend, athrow, aclose) 
    137                target_is_generator = bool(getattr(target_iterator, "asend", None)) 
    138 
    139            sent_in = None 
    140            while True: 
    141                ## Read from target_iterator 
    142                # If the target is a generator, we will advance it with `asend` 
    143                # otherwise, we will use `anext` 
    144                if target_is_generator: 
    145                    next_value = await target_iterator.asend(sent_in)  # type: ignore 
    146                else: 
    147                    next_value = await target_iterator.__anext__() 
    148                ## Yield from Wrapper to caller 
    149                try: 
    150                    # yield latest value from target 
    151                    # exceptions from `athrow` and `aclose` are injected here 
    152                    sent_in = yield next_value 
    153                except GeneratorExit: 
    154                    # if wrapper received `aclose` while waiting on yield, 
    155                    # it will raise GeneratorExit here 
    156                    if target_is_generator: 
    157                        # pass to inner target_iterator for handling 
    158                        await cast(AsyncGenerator["_Y", None], target_iterator).aclose() 
    159                    else: 
    160                        raise 
    161                    return 
    162                except:  # noqa: E722 
    163                    # bare except catches any exception passed to `athrow` 
    164                    if target_is_generator: 
    165                        # delegate error handling to target_iterator 
    166                        await cast(AsyncGenerator["_Y", None], target_iterator).athrow( 
    167                            cast(BaseException, sys.exc_info()[1]) 
    168                        ) 
    169                    else: 
    170                        raise 
    171            return 
    172        except StopAsyncIteration: 
    173            # if iterator exhausted, return 
    174            return 
    175        # handle exceptions raised by the target_iterator 
    176        # pylint: disable=broad-except 
    177        # This function explicitly must deal with broad exceptions. 
    178        except Exception as exc: 
    179            # defer to shared logic for handling errors 
    180            next_sleep = _retry_error_helper( 
    181                exc, 
    182                deadline, 
    183                sleep_iter, 
    184                error_list, 
    185                predicate, 
    186                on_error, 
    187                exception_factory, 
    188                timeout, 
    189            ) 
    190            # if exception not raised, sleep before next attempt 
    191            await asyncio.sleep(next_sleep) 
    192 
    193        finally: 
    194            if target_is_generator and target_iterator is not None: 
    195                await cast(AsyncGenerator["_Y", None], target_iterator).aclose() 
    196 
    197 
    198class AsyncStreamingRetry(_BaseRetry): 
    199    """Exponential retry decorator for async streaming rpcs. 
    200 
    201    This class returns an AsyncGenerator when called, which wraps the target 
    202    stream in retry logic. If any exception is raised by the target, the 
    203    entire stream will be retried within the wrapper. 
    204 
    205    Although the default behavior is to retry transient API errors, a 
    206    different predicate can be provided to retry other exceptions. 
    207 
    208    Important Note: when a stream is encounters a retryable error, it will 
    209    silently construct a fresh iterator instance in the background 
    210    and continue yielding (likely duplicate) values as if no error occurred. 
    211    This is the most general way to retry a stream, but it often is not the 
    212    desired behavior. Example: iter([1, 2, 1/0]) -> [1, 2, 1, 2, ...] 
    213 
    214    There are two ways to build more advanced retry logic for streams: 
    215 
    216    1. Wrap the target 
    217        Use a ``target`` that maintains state between retries, and creates a 
    218        different generator on each retry call. For example, you can wrap a 
    219        grpc call in a function that modifies the request based on what has 
    220        already been returned: 
    221 
    222        .. code-block:: python 
    223 
    224            async def attempt_with_modified_request(target, request, seen_items=[]): 
    225                # remove seen items from request on each attempt 
    226                new_request = modify_request(request, seen_items) 
    227                new_generator = await target(new_request) 
    228                async for item in new_generator: 
    229                    yield item 
    230                    seen_items.append(item) 
    231 
    232            retry_wrapped = AsyncRetry(is_stream=True,...)(attempt_with_modified_request, target, request, []) 
    233 
    234        2. Wrap the retry generator 
    235            Alternatively, you can wrap the retryable generator itself before 
    236            passing it to the end-user to add a filter on the stream. For 
    237            example, you can keep track of the items that were successfully yielded 
    238            in previous retry attempts, and only yield new items when the 
    239            new attempt surpasses the previous ones: 
    240 
    241            .. code-block:: python 
    242 
    243                async def retryable_with_filter(target): 
    244                    stream_idx = 0 
    245                    # reset stream_idx when the stream is retried 
    246                    def on_error(e): 
    247                        nonlocal stream_idx 
    248                        stream_idx = 0 
    249                    # build retryable 
    250                    retryable_gen = AsyncRetry(is_stream=True, ...)(target) 
    251                    # keep track of what has been yielded out of filter 
    252                    seen_items = [] 
    253                    async for item in retryable_gen: 
    254                        if stream_idx >= len(seen_items): 
    255                            yield item 
    256                            seen_items.append(item) 
    257                        elif item != previous_stream[stream_idx]: 
    258                            raise ValueError("Stream differs from last attempt")" 
    259                        stream_idx += 1 
    260 
    261                filter_retry_wrapped = retryable_with_filter(target) 
    262 
    263    Args: 
    264        predicate (Callable[Exception]): A callable that should return ``True`` 
    265            if the given exception is retryable. 
    266        initial (float): The minimum amount of time to delay in seconds. This 
    267            must be greater than 0. 
    268        maximum (float): The maximum amount of time to delay in seconds. 
    269        multiplier (float): The multiplier applied to the delay. 
    270        timeout (Optional[float]): How long to keep retrying in seconds. 
    271            Note: timeout is only checked before initiating a retry, so the target may 
    272            run past the timeout value as long as it is healthy. 
    273        on_error (Optional[Callable[Exception]]): A function to call while processing 
    274            a retryable exception. Any error raised by this function will 
    275            *not* be caught. 
    276        is_stream (bool): Indicates whether the input function 
    277            should be treated as a stream function (i.e. an AsyncGenerator, 
    278            or function or coroutine that returns an AsyncIterable). 
    279            If True, the iterable will be wrapped with retry logic, and any 
    280            failed outputs will restart the stream. If False, only the input 
    281            function call itself will be retried. Defaults to False. 
    282            To avoid duplicate values, retryable streams should typically be 
    283            wrapped in additional filter logic before use. 
    284        deadline (float): DEPRECATED use ``timeout`` instead. If set it will 
    285        override ``timeout`` parameter. 
    286    """ 
    287 
    288    def __call__( 
    289        self, 
    290        func: Callable[..., AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]], 
    291        on_error: Callable[[Exception], Any] | None = None, 
    292    ) -> Callable[_P, Awaitable[AsyncGenerator[_Y, None]]]: 
    293        """Wrap a callable with retry behavior. 
    294 
    295        Args: 
    296            func (Callable): The callable or stream to add retry behavior to. 
    297            on_error (Optional[Callable[Exception]]): If given, the 
    298                on_error callback will be called with each retryable exception 
    299                raised by the wrapped function. Any error raised by this 
    300                function will *not* be caught. If on_error was specified in the 
    301                constructor, this value will be ignored. 
    302 
    303        Returns: 
    304            Callable: A callable that will invoke ``func`` with retry 
    305                behavior. 
    306        """ 
    307        if self._on_error is not None: 
    308            on_error = self._on_error 
    309 
    310        @functools.wraps(func) 
    311        async def retry_wrapped_func( 
    312            *args: _P.args, **kwargs: _P.kwargs 
    313        ) -> AsyncGenerator[_Y, None]: 
    314            """A wrapper that calls target function with retry.""" 
    315            sleep_generator = exponential_sleep_generator( 
    316                self._initial, self._maximum, multiplier=self._multiplier 
    317            ) 
    318            return retry_target_stream( 
    319                func, 
    320                self._predicate, 
    321                sleep_generator, 
    322                self._timeout, 
    323                on_error, 
    324                init_args=args, 
    325                init_kwargs=kwargs, 
    326            ) 
    327 
    328        return retry_wrapped_func