1# Copyright 2020 Google LLC 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""AsyncIO helpers for :mod:`grpc` supporting 3.7+. 
    16 
Please refer to the more detailed docstrings in grpc_helpers.py for the
following functions. This module implements the same surface with AsyncIO
semantics.
    19""" 
    20 
    21import asyncio 
    22import functools 
    23import warnings 
    24 
    25from typing import AsyncGenerator, Generic, Iterator, Optional, TypeVar 
    26 
    27import grpc 
    28from grpc import aio 
    29 
    30from google.api_core import exceptions, general_helpers, grpc_helpers 
    31 
    32# denotes the proto response type for grpc calls 
    33P = TypeVar("P") 
    34 
# NOTE(lidiz) Alternatively, we could hack "__getattribute__" to perform the
# patching automatically. But that would add the overhead of an extra Python
# function call to every single send and receive.
    38 
    39 
    40class _WrappedCall(aio.Call): 
    41    def __init__(self): 
    42        self._call = None 
    43 
    44    def with_call(self, call): 
    45        """Supplies the call object separately to keep __init__ clean.""" 
    46        self._call = call 
    47        return self 
    48 
    49    async def initial_metadata(self): 
    50        return await self._call.initial_metadata() 
    51 
    52    async def trailing_metadata(self): 
    53        return await self._call.trailing_metadata() 
    54 
    55    async def code(self): 
    56        return await self._call.code() 
    57 
    58    async def details(self): 
    59        return await self._call.details() 
    60 
    61    def cancelled(self): 
    62        return self._call.cancelled() 
    63 
    64    def done(self): 
    65        return self._call.done() 
    66 
    67    def time_remaining(self): 
    68        return self._call.time_remaining() 
    69 
    70    def cancel(self): 
    71        return self._call.cancel() 
    72 
    73    def add_done_callback(self, callback): 
    74        self._call.add_done_callback(callback) 
    75 
    76    async def wait_for_connection(self): 
    77        try: 
    78            await self._call.wait_for_connection() 
    79        except grpc.RpcError as rpc_error: 
    80            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    81 
    82 
    83class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall): 
    84    def __await__(self) -> Iterator[P]: 
    85        try: 
    86            response = yield from self._call.__await__() 
    87            return response 
    88        except grpc.RpcError as rpc_error: 
    89            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    90 
    91 
    92class _WrappedStreamResponseMixin(Generic[P], _WrappedCall): 
    93    def __init__(self): 
    94        self._wrapped_async_generator = None 
    95 
    96    async def read(self) -> P: 
    97        try: 
    98            return await self._call.read() 
    99        except grpc.RpcError as rpc_error: 
    100            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    101 
    102    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]: 
    103        try: 
    104            # NOTE(lidiz) coverage doesn't understand the exception raised from 
    105            # __anext__ method. It is covered by test case: 
    106            #     test_wrap_stream_errors_aiter_non_rpc_error 
    107            async for response in self._call:  # pragma: no branch 
    108                yield response 
    109        except grpc.RpcError as rpc_error: 
    110            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    111 
    112    def __aiter__(self) -> AsyncGenerator[P, None]: 
    113        if not self._wrapped_async_generator: 
    114            self._wrapped_async_generator = self._wrapped_aiter() 
    115        return self._wrapped_async_generator 
    116 
    117 
    118class _WrappedStreamRequestMixin(_WrappedCall): 
    119    async def write(self, request): 
    120        try: 
    121            await self._call.write(request) 
    122        except grpc.RpcError as rpc_error: 
    123            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    124 
    125    async def done_writing(self): 
    126        try: 
    127            await self._call.done_writing() 
    128        except grpc.RpcError as rpc_error: 
    129            raise exceptions.from_grpc_error(rpc_error) from rpc_error 
    130 
    131 
    132# NOTE(lidiz) Implementing each individual class separately, so we don't 
    133# expose any API that should not be seen. E.g., __aiter__ in unary-unary 
    134# RPC, or __await__ in stream-stream RPC. 
    135class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall): 
    136    """Wrapped UnaryUnaryCall to map exceptions.""" 
    137 
    138 
    139class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall): 
    140    """Wrapped UnaryStreamCall to map exceptions.""" 
    141 
    142 
    143class _WrappedStreamUnaryCall( 
    144    _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall 
    145): 
    146    """Wrapped StreamUnaryCall to map exceptions.""" 
    147 
    148 
    149class _WrappedStreamStreamCall( 
    150    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall 
    151): 
    152    """Wrapped StreamStreamCall to map exceptions.""" 
    153 
    154 
    155# public type alias denoting the return type of async streaming gapic calls 
    156GrpcAsyncStream = _WrappedStreamResponseMixin 
    157# public type alias denoting the return type of unary gapic calls 
    158AwaitableGrpcCall = _WrappedUnaryResponseMixin 
    159 
    160 
    161def _wrap_unary_errors(callable_): 
    162    """Map errors for Unary-Unary async callables.""" 
    163 
    164    @functools.wraps(callable_) 
    165    def error_remapped_callable(*args, **kwargs): 
    166        call = callable_(*args, **kwargs) 
    167        return _WrappedUnaryUnaryCall().with_call(call) 
    168 
    169    return error_remapped_callable 
    170 
    171 
    172def _wrap_stream_errors(callable_, wrapper_type): 
    173    """Map errors for streaming RPC async callables.""" 
    174 
    175    @functools.wraps(callable_) 
    176    async def error_remapped_callable(*args, **kwargs): 
    177        call = callable_(*args, **kwargs) 
    178        call = wrapper_type().with_call(call) 
    179        await call.wait_for_connection() 
    180        return call 
    181 
    182    return error_remapped_callable 
    183 
    184 
    185def wrap_errors(callable_): 
    186    """Wrap a gRPC async callable and map :class:`grpc.RpcErrors` to 
    187    friendly error classes. 
    188 
    189    Errors raised by the gRPC callable are mapped to the appropriate 
    190    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The 
    191    original `grpc.RpcError` (which is usually also a `grpc.Call`) is 
    192    available from the ``response`` property on the mapped exception. This 
    193    is useful for extracting metadata from the original error. 
    194 
    195    Args: 
    196        callable_ (Callable): A gRPC callable. 
    197 
    198    Returns: Callable: The wrapped gRPC callable. 
    199    """ 
    200    grpc_helpers._patch_callable_name(callable_) 
    201 
    202    if isinstance(callable_, aio.UnaryStreamMultiCallable): 
    203        return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall) 
    204    elif isinstance(callable_, aio.StreamUnaryMultiCallable): 
    205        return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall) 
    206    elif isinstance(callable_, aio.StreamStreamMultiCallable): 
    207        return _wrap_stream_errors(callable_, _WrappedStreamStreamCall) 
    208    else: 
    209        return _wrap_unary_errors(callable_) 
    210 
    211 
    212def create_channel( 
    213    target, 
    214    credentials=None, 
    215    scopes=None, 
    216    ssl_credentials=None, 
    217    credentials_file=None, 
    218    quota_project_id=None, 
    219    default_scopes=None, 
    220    default_host=None, 
    221    compression=None, 
    222    attempt_direct_path: Optional[bool] = False, 
    223    **kwargs 
    224): 
    225    """Create an AsyncIO secure channel with credentials. 
    226 
    227    Args: 
    228        target (str): The target service address in the format 'hostname:port'. 
    229        credentials (google.auth.credentials.Credentials): The credentials. If 
    230            not specified, then this function will attempt to ascertain the 
    231            credentials from the environment using :func:`google.auth.default`. 
    232        scopes (Sequence[str]): A optional list of scopes needed for this 
    233            service. These are only used when credentials are not specified and 
    234            are passed to :func:`google.auth.default`. 
    235        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel 
    236            credentials. This can be used to specify different certificates. 
    237        credentials_file (str): Deprecated. A file with credentials that can be loaded with 
    238            :func:`google.auth.load_credentials_from_file`. This argument is 
    239            mutually exclusive with credentials. This argument will be 
    240            removed in the next major version of `google-api-core`. 
    241 
    242            .. warning:: 
    243                Important: If you accept a credential configuration (credential JSON/File/Stream) 
    244                from an external source for authentication to Google Cloud Platform, you must 
    245                validate it before providing it to any Google API or client library. Providing an 
    246                unvalidated credential configuration to Google APIs or libraries can compromise 
    247                the security of your systems and data. For more information, refer to 
    248                `Validate credential configurations from external sources`_. 
    249 
    250            .. _Validate credential configurations from external sources: 
    251 
    252            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials 
    253        quota_project_id (str): An optional project to use for billing and quota. 
    254        default_scopes (Sequence[str]): Default scopes passed by a Google client 
    255            library. Use 'scopes' for user-defined scopes. 
    256        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com". 
    257        compression (grpc.Compression): An optional value indicating the 
    258            compression method to be used over the lifetime of the channel. 
    259        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted 
    260            when the request is made. Direct Path is only available within a Google 
    261            Compute Engine (GCE) environment and provides a proxyless connection 
    262            which increases the available throughput, reduces latency, and increases 
    263            reliability. Note: 
    264 
    265            - This argument should only be set in a GCE environment and for Services 
    266              that are known to support Direct Path. 
    267            - If this argument is set outside of GCE, then this request will fail 
    268              unless the back-end service happens to have configured fall-back to DNS. 
    269            - If the request causes a `ServiceUnavailable` response, it is recommended 
    270              that the client repeat the request with `attempt_direct_path` set to 
    271              `False` as the Service may not support Direct Path. 
    272            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will 
    273              result in `ValueError` as this combination  is not yet supported. 
    274 
    275        kwargs: Additional key-word args passed to :func:`aio.secure_channel`. 
    276 
    277    Returns: 
    278        aio.Channel: The created channel. 
    279 
    280    Raises: 
    281        google.api_core.DuplicateCredentialArgs: If both a credentials object and credentials_file are passed. 
    282        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`. 
    283    """ 
    284 
    285    if credentials_file is not None: 
    286        warnings.warn(general_helpers._CREDENTIALS_FILE_WARNING, DeprecationWarning) 
    287 
    288    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`, 
    289    # raise ValueError as this is not yet supported. 
    290    # See https://github.com/googleapis/python-api-core/issues/590 
    291    if ssl_credentials and attempt_direct_path: 
    292        raise ValueError("Using ssl_credentials with Direct Path is not supported") 
    293 
    294    composite_credentials = grpc_helpers._create_composite_credentials( 
    295        credentials=credentials, 
    296        credentials_file=credentials_file, 
    297        scopes=scopes, 
    298        default_scopes=default_scopes, 
    299        ssl_credentials=ssl_credentials, 
    300        quota_project_id=quota_project_id, 
    301        default_host=default_host, 
    302    ) 
    303 
    304    if attempt_direct_path: 
    305        target = grpc_helpers._modify_target_for_direct_path(target) 
    306 
    307    return aio.secure_channel( 
    308        target, composite_credentials, compression=compression, **kwargs 
    309    ) 
    310 
    311 
    312class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall): 
    313    """Fake implementation for unary-unary RPCs. 
    314 
    315    It is a dummy object for response message. Supply the intended response 
    316    upon the initialization, and the coroutine will return the exact response 
    317    message. 
    318    """ 
    319 
    320    def __init__(self, response=object()): 
    321        self.response = response 
    322        self._future = asyncio.get_event_loop().create_future() 
    323        self._future.set_result(self.response) 
    324 
    325    def __await__(self): 
    326        response = yield from self._future.__await__() 
    327        return response 
    328 
    329 
    330class FakeStreamUnaryCall(_WrappedStreamUnaryCall): 
    331    """Fake implementation for stream-unary RPCs. 
    332 
    333    It is a dummy object for response message. Supply the intended response 
    334    upon the initialization, and the coroutine will return the exact response 
    335    message. 
    336    """ 
    337 
    338    def __init__(self, response=object()): 
    339        self.response = response 
    340        self._future = asyncio.get_event_loop().create_future() 
    341        self._future.set_result(self.response) 
    342 
    343    def __await__(self): 
    344        response = yield from self._future.__await__() 
    345        return response 
    346 
    347    async def wait_for_connection(self): 
    348        pass