1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""AsyncIO helpers for :mod:`grpc` supporting 3.7+.
16
For more detailed documentation of the functions below, see the corresponding
docstrings in grpc_helpers.py. This module implements the same surface with
AsyncIO semantics.
19"""
20
21import asyncio
22import functools
23
24from typing import AsyncGenerator, Generic, Iterator, Optional, TypeVar
25
26import grpc
27from grpc import aio
28
29from google.api_core import exceptions, grpc_helpers
30
# Type variable standing for the proto response message type of a gRPC call.
# It parameterizes the generic wrapper mixins and the public aliases below.
P = TypeVar("P")
33
34# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
35# automatic patching for us. But that means the overhead of creating an
36# extra Python function spreads to every single send and receive.
37
38
class _WrappedCall(aio.Call):
    """Delegating wrapper around a ``grpc.aio`` call object.

    Each ``aio.Call`` method simply forwards to the wrapped call. Of the
    methods defined here, only :meth:`wait_for_connection` converts a
    ``grpc.RpcError`` into the matching :mod:`google.api_core.exceptions`
    type; subclasses add the same conversion for their own operations.
    """

    def __init__(self):
        # The delegate is attached afterwards via ``with_call``.
        self._call = None

    def with_call(self, call):
        """Attach ``call`` as the delegate and return ``self`` for chaining.

        Keeping this separate from ``__init__`` lets the wrapper be
        constructed without arguments.
        """
        self._call = call
        return self

    async def initial_metadata(self):
        metadata = await self._call.initial_metadata()
        return metadata

    async def trailing_metadata(self):
        metadata = await self._call.trailing_metadata()
        return metadata

    async def code(self):
        status_code = await self._call.code()
        return status_code

    async def details(self):
        status_details = await self._call.details()
        return status_details

    def cancelled(self):
        delegate = self._call
        return delegate.cancelled()

    def done(self):
        delegate = self._call
        return delegate.done()

    def time_remaining(self):
        delegate = self._call
        return delegate.time_remaining()

    def cancel(self):
        delegate = self._call
        return delegate.cancel()

    def add_done_callback(self, callback):
        delegate = self._call
        delegate.add_done_callback(callback)

    async def wait_for_connection(self):
        """Wait until the underlying call is connected, mapping gRPC errors."""
        try:
            await self._call.wait_for_connection()
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc
80
81
class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
    """Mixin making a wrapped single-response call awaitable.

    Awaiting the wrapper awaits the underlying call and returns its response.
    A ``grpc.RpcError`` raised along the way is re-raised as the matching
    :mod:`google.api_core.exceptions` type, chained to the original error.
    """

    def __await__(self) -> Iterator[P]:
        try:
            return (yield from self._call.__await__())
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc
89
90
class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
    """Mixin for wrapped calls that stream response messages to the client.

    Both :meth:`read` and ``async for`` iteration re-raise ``grpc.RpcError``
    as the matching :mod:`google.api_core.exceptions` type.
    """

    def __init__(self):
        # Chain to the base initializer so ``_call`` is defined (as ``None``)
        # before ``with_call`` runs, exactly as ``_WrappedCall.__init__`` does.
        super().__init__()
        # Created lazily by ``__aiter__`` and then reused, so every
        # ``async for`` over this wrapper shares one underlying generator.
        self._wrapped_async_generator = None

    async def read(self) -> P:
        """Read the next response message, mapping gRPC errors."""
        try:
            return await self._call.read()
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
        """Yield responses from the wrapped call, mapping gRPC errors."""
        try:
            # NOTE(lidiz) coverage doesn't understand the exception raised from
            # __anext__ method. It is covered by test case:
            # test_wrap_stream_errors_aiter_non_rpc_error
            async for response in self._call:  # pragma: no branch
                yield response
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    def __aiter__(self) -> AsyncGenerator[P, None]:
        # Explicit sentinel check: only build the generator the first time.
        if self._wrapped_async_generator is None:
            self._wrapped_async_generator = self._wrapped_aiter()
        return self._wrapped_async_generator
115
116
class _WrappedStreamRequestMixin(_WrappedCall):
    """Mixin for wrapped calls whose requests are streamed by the client.

    Both write-side operations re-raise ``grpc.RpcError`` as the matching
    :mod:`google.api_core.exceptions` type, chained to the original error.
    """

    async def write(self, request):
        """Forward one request message to the wrapped call."""
        try:
            await self._call.write(request)
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc

    async def done_writing(self):
        """Forward ``done_writing`` (end of the request stream) to the wrapped call."""
        try:
            await self._call.done_writing()
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc
129
130
131# NOTE(lidiz) Implementing each individual class separately, so we don't
132# expose any API that should not be seen. E.g., __aiter__ in unary-unary
133# RPC, or __await__ in stream-stream RPC.
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
    """Unary-unary call wrapper whose awaited errors become API exceptions."""
136
137
class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
    """Unary-stream call wrapper whose read/iteration errors become API exceptions."""
140
141
class _WrappedStreamUnaryCall(
    _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
    """Stream-unary call wrapper whose write/await errors become API exceptions."""
146
147
class _WrappedStreamStreamCall(
    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
):
    """Stream-stream call wrapper whose write/read errors become API exceptions."""
152
153
# Public type alias for the object returned by async streaming GAPIC calls:
# it supports ``read()`` and ``async for`` over the response messages.
GrpcAsyncStream = _WrappedStreamResponseMixin[P]
# Public type alias for the object returned by async unary-response GAPIC
# calls: awaiting it yields the response message.
AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]
158
159
def _wrap_unary_errors(callable_):
    """Wrap a unary-unary async callable so its call object maps errors.

    The returned function invokes ``callable_`` synchronously and hands the
    resulting call to a fresh ``_WrappedUnaryUnaryCall``. Awaiting that
    wrapper raises ``GoogleAPICallError`` subclasses instead of
    ``grpc.RpcError``.
    """
    grpc_helpers._patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        unary_call = callable_(*args, **kwargs)
        wrapper = _WrappedUnaryUnaryCall()
        return wrapper.with_call(unary_call)

    return error_remapped_callable
170
171
def _wrap_stream_errors(callable_):
    """Wrap a streaming async callable so its call object maps errors.

    The returned coroutine function invokes ``callable_``, wraps the result in
    the wrapper class matching its call kind, waits for the connection, and
    returns the wrapper. A ``TypeError`` is raised for unsupported call types.
    """
    grpc_helpers._patch_callable_name(callable_)

    @functools.wraps(callable_)
    async def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)

        # Checked in order; the first matching call kind wins.
        dispatch = (
            (aio.UnaryStreamCall, _WrappedUnaryStreamCall),
            (aio.StreamUnaryCall, _WrappedStreamUnaryCall),
            (aio.StreamStreamCall, _WrappedStreamStreamCall),
        )
        for call_kind, wrapper_cls in dispatch:
            if isinstance(raw_call, call_kind):
                wrapped = wrapper_cls().with_call(raw_call)
                break
        else:
            raise TypeError("Unexpected type of call %s" % type(raw_call))

        await wrapped.wait_for_connection()
        return wrapped

    return error_remapped_callable
193
194
def wrap_errors(callable_):
    """Wrap a gRPC async callable so ``grpc.RpcError`` becomes a friendly error.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The
    original ``grpc.RpcError`` (which is usually also a ``grpc.Call``) is
    available from the ``response`` property on the mapped exception, which
    is useful for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable. Unary-unary multi-callables get
        the unary wrapper; every other callable gets the streaming wrapper.
    """
    is_unary = isinstance(callable_, aio.UnaryUnaryMultiCallable)
    wrapper = _wrap_unary_errors if is_unary else _wrap_stream_errors
    return wrapper(callable_)
214
215
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs
):
    """Create an AsyncIO secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. These are only used when credentials are not specified and
            are passed to :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint, e.g. "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted
            when the request is made. Direct Path is only available within a Google
            Compute Engine (GCE) environment and provides a proxyless connection
            which increases the available throughput, reduces latency, and increases
            reliability. Note:

            - This argument should only be set in a GCE environment and for Services
              that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will fail
              unless the back-end service happens to have configured fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is recommended
              that the client repeat the request with `attempt_direct_path` set to
              `False` as the Service may not support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
              result in `ValueError` as this combination is not yet supported.

        kwargs: Additional keyword args passed to :func:`aio.secure_channel`.

    Returns:
        aio.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is
            set to `True`.
    """

    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`,
    # raise ValueError as this is not yet supported. The check runs before any
    # credential resolution so the caller fails fast.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    composite_credentials = grpc_helpers._create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        scopes=scopes,
        default_scopes=default_scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    if attempt_direct_path:
        target = grpc_helpers._modify_target_for_direct_path(target)

    return aio.secure_channel(
        target, composite_credentials, compression=compression, **kwargs
    )
298
299
class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall):
    """Test double for unary-unary RPCs.

    Construct it with the desired response message; awaiting the instance
    returns exactly that object. No real gRPC call is wrapped.
    """

    def __init__(self, response=object()):
        self.response = response
        # The future is created on the loop from ``asyncio.get_event_loop()``
        # and resolved immediately, so awaiting it completes without suspending.
        event_loop = asyncio.get_event_loop()
        resolved = event_loop.create_future()
        resolved.set_result(response)
        self._future = resolved

    def __await__(self):
        return (yield from self._future.__await__())
316
317
class FakeStreamUnaryCall(_WrappedStreamUnaryCall):
    """Test double for stream-unary RPCs.

    Construct it with the desired response message; awaiting the instance
    returns exactly that object. ``wait_for_connection`` does nothing because
    no real gRPC call is wrapped.
    """

    def __init__(self, response=object()):
        self.response = response
        # The future is created on the loop from ``asyncio.get_event_loop()``
        # and resolved immediately, so awaiting it completes without suspending.
        event_loop = asyncio.get_event_loop()
        resolved = event_loop.create_future()
        resolved.set_result(response)
        self._future = resolved

    def __await__(self):
        return (yield from self._future.__await__())

    async def wait_for_connection(self):
        return None