1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""AsyncIO helpers for :mod:`grpc` supporting 3.7+.

Please refer to the more detailed docstrings in grpc_helpers.py for the
following functions. This module implements the same surface with AsyncIO
semantics.
"""
20
21import asyncio
22import functools
23import warnings
24
25from typing import AsyncGenerator, Generic, Iterator, Optional, TypeVar
26
27import grpc
28from grpc import aio
29
30from google.api_core import exceptions, general_helpers, grpc_helpers
31
# Type variable denoting the proto response message type of a gRPC call; it
# parameterizes the generic wrapped-call mixins defined in this module.
P = TypeVar("P")
34
# NOTE(lidiz) Alternatively, we could hook "__getattribute__" to perform the
# patching automatically. But that would spread the overhead of an extra
# Python function call to every single send and receive.
38
39
class _WrappedCall(aio.Call):
    """Base wrapper that forwards the ``aio.Call`` surface to a real call.

    The real call is attached after construction via :meth:`with_call`; every
    method below delegates to it. :meth:`wait_for_connection` additionally
    re-raises :class:`grpc.RpcError` as the matching
    :class:`google.api_core.exceptions.GoogleAPICallError` subclass.
    """

    def __init__(self):
        # Placeholder until ``with_call`` attaches the real call.
        self._call = None

    def with_call(self, call):
        """Attach ``call`` as the delegate and return ``self`` for chaining.

        Kept separate from ``__init__`` so construction stays argument-free.
        """
        self._call = call
        return self

    async def initial_metadata(self):
        """Await and return the delegate's initial metadata."""
        metadata = await self._call.initial_metadata()
        return metadata

    async def trailing_metadata(self):
        """Await and return the delegate's trailing metadata."""
        metadata = await self._call.trailing_metadata()
        return metadata

    async def code(self):
        """Await and return the delegate's status code."""
        status_code = await self._call.code()
        return status_code

    async def details(self):
        """Await and return the delegate's status details."""
        status_details = await self._call.details()
        return status_details

    def cancelled(self):
        """Return the delegate's ``cancelled()`` result."""
        return self._call.cancelled()

    def done(self):
        """Return the delegate's ``done()`` result."""
        return self._call.done()

    def time_remaining(self):
        """Return the delegate's ``time_remaining()`` result."""
        return self._call.time_remaining()

    def cancel(self):
        """Forward ``cancel()`` to the delegate and return its result."""
        return self._call.cancel()

    def add_done_callback(self, callback):
        """Register ``callback`` on the delegate; returns ``None``."""
        self._call.add_done_callback(callback)

    async def wait_for_connection(self):
        """Await the delegate's connection, mapping gRPC errors.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: Built by
                ``exceptions.from_grpc_error`` from a :class:`grpc.RpcError`
                raised by the delegate, chained to that original error.
        """
        try:
            await self._call.wait_for_connection()
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err
81
82
class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
    """Mixin making a wrapped single-response call awaitable with error mapping."""

    def __await__(self) -> Iterator[P]:
        """Drive the wrapped call's awaitable and return its response.

        Returns:
            P: The response produced by awaiting the wrapped call.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: Built by
                ``exceptions.from_grpc_error`` when the wrapped call raises
                :class:`grpc.RpcError`, chained to that original error.
        """
        try:
            # Delegate the whole await protocol, then hand back its result.
            return (yield from self._call.__await__())
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err
90
91
class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
    """Mixin for wrapped calls that stream responses back to the caller.

    Both :meth:`read` and asynchronous iteration re-raise
    :class:`grpc.RpcError` from the wrapped call as the matching
    :class:`google.api_core.exceptions.GoogleAPICallError` subclass.
    """

    def __init__(self):
        # Chain to ``_WrappedCall.__init__`` so ``_call`` is always defined
        # (as ``None``) until ``with_call`` attaches the real call, matching
        # the other wrapper classes.
        super().__init__()
        # Created lazily by ``__aiter__`` and then reused, so every
        # ``async for`` over this object consumes the same response stream.
        self._wrapped_async_generator = None

    async def read(self) -> P:
        """Read the next message from the wrapped call.

        Returns:
            P: Whatever the wrapped call's ``read()`` returns (grpc signals
                end of stream with ``grpc.aio.EOF``).

        Raises:
            google.api_core.exceptions.GoogleAPICallError: Mapped from a
                :class:`grpc.RpcError` raised by the wrapped call.
        """
        try:
            return await self._call.read()
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
        """Yield responses from the wrapped call, mapping gRPC errors."""
        try:
            # NOTE(lidiz) coverage doesn't understand the exception raised from
            # __anext__ method. It is covered by test case:
            #     test_wrap_stream_errors_aiter_non_rpc_error
            async for response in self._call:  # pragma: no branch
                yield response
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    def __aiter__(self) -> AsyncGenerator[P, None]:
        """Return the (cached) error-mapping async generator over responses."""
        if self._wrapped_async_generator is None:
            self._wrapped_async_generator = self._wrapped_aiter()
        return self._wrapped_async_generator
116
117
class _WrappedStreamRequestMixin(_WrappedCall):
    """Mixin adding the request-streaming methods with gRPC error mapping.

    A :class:`grpc.RpcError` raised by the wrapped call's ``write`` or
    ``done_writing`` is re-raised as the matching
    :class:`google.api_core.exceptions.GoogleAPICallError` subclass, chained
    to the original error. Both methods return ``None``.
    """

    async def write(self, request):
        """Forward ``request`` to the wrapped call's ``write``."""
        try:
            await self._call.write(request)
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err

    async def done_writing(self):
        """Forward to the wrapped call's ``done_writing``."""
        try:
            await self._call.done_writing()
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err
130
131
# NOTE(lidiz) Each RPC shape gets its own class, so that no API is exposed
# where it does not belong, e.g. __aiter__ on a unary-unary RPC or __await__
# on a stream-stream RPC.
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
    """Wrapped UnaryUnaryCall that maps gRPC errors raised while awaiting it."""


class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
    """Wrapped UnaryStreamCall that maps gRPC errors from reads and iteration."""


class _WrappedStreamUnaryCall(
    _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
    """Wrapped StreamUnaryCall that maps gRPC errors from writes and awaiting."""


class _WrappedStreamStreamCall(
    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
):
    """Wrapped StreamStreamCall that maps gRPC errors from writes and reads."""
153
154
# Public type alias for the return type of async streaming GAPIC calls; the
# unary-stream and stream-stream wrappers both derive from this mixin.
GrpcAsyncStream = _WrappedStreamResponseMixin
# Public type alias for the return type of unary-response GAPIC calls; the
# unary-unary and stream-unary wrappers both derive from this mixin.
AwaitableGrpcCall = _WrappedUnaryResponseMixin
159
160
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary async callables.

    Args:
        callable_ (Callable): The gRPC unary-unary multi-callable to wrap.

    Returns:
        Callable: A plain (non-async) function that invokes ``callable_`` and
            returns the resulting call wrapped in ``_WrappedUnaryUnaryCall``.
            Awaiting that wrapper raises mapped API exceptions instead of
            :class:`grpc.RpcError`.
    """

    def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)
        wrapped = _WrappedUnaryUnaryCall()
        return wrapped.with_call(raw_call)

    # Copy the wrapped callable's metadata (name, docstring, ...) onto ours.
    return functools.wraps(callable_)(error_remapped_callable)
170
171
def _wrap_stream_errors(callable_, wrapper_type):
    """Map errors for streaming RPC async callables.

    Args:
        callable_ (Callable): The gRPC streaming multi-callable to wrap.
        wrapper_type (type): The wrapper class (one of the ``_Wrapped*Call``
            classes) instantiated around each call.

    Returns:
        Callable: A coroutine function. When awaited it invokes ``callable_``,
            wraps the call in ``wrapper_type``, awaits its
            ``wait_for_connection()`` (which maps gRPC errors), and returns
            the wrapper.
    """

    async def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)
        wrapped_call = wrapper_type().with_call(raw_call)
        await wrapped_call.wait_for_connection()
        return wrapped_call

    # Copy the wrapped callable's metadata (name, docstring, ...) onto ours.
    return functools.wraps(callable_)(error_remapped_callable)
183
184
def wrap_errors(callable_):
    """Wrap a gRPC async callable and map :class:`grpc.RpcError` to
    friendly error classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The
    original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns: Callable: The wrapped gRPC callable.
    """
    grpc_helpers._patch_callable_name(callable_)

    # Checked in order; anything that is not one of these streaming
    # multi-callables is treated as unary-unary.
    streaming_wrappers = (
        (aio.UnaryStreamMultiCallable, _WrappedUnaryStreamCall),
        (aio.StreamUnaryMultiCallable, _WrappedStreamUnaryCall),
        (aio.StreamStreamMultiCallable, _WrappedStreamStreamCall),
    )
    for multicallable_type, wrapper_type in streaming_wrappers:
        if isinstance(callable_, multicallable_type):
            return _wrap_stream_errors(callable_, wrapper_type)
    return _wrap_unary_errors(callable_)
210
211
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs
):
    """Create an AsyncIO secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. These are only used when credentials are not specified and
            are passed to :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): Deprecated. A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials. This argument will be
            removed in the next major version of `google-api-core`.
            Passing it emits a :class:`DeprecationWarning`.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted
            when the request is made. Direct Path is only available within a Google
            Compute Engine (GCE) environment and provides a proxyless connection
            which increases the available throughput, reduces latency, and increases
            reliability. Note:

            - This argument should only be set in a GCE environment and for Services
              that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will fail
              unless the back-end service happens to have configured fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is recommended
              that the client repeat the request with `attempt_direct_path` set to
              `False` as the Service may not support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
              result in `ValueError` as this combination is not yet supported.

        kwargs: Additional keyword args passed to :func:`aio.secure_channel`.

    Returns:
        aio.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a credentials object and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`.
    """

    if credentials_file is not None:
        # stacklevel=2 attributes the deprecation to the code that passed
        # ``credentials_file`` rather than to this helper itself.
        warnings.warn(
            general_helpers._CREDENTIALS_FILE_WARNING,
            DeprecationWarning,
            stacklevel=2,
        )

    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`,
    # raise ValueError as this is not yet supported.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    composite_credentials = grpc_helpers._create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        scopes=scopes,
        default_scopes=default_scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    if attempt_direct_path:
        target = grpc_helpers._modify_target_for_direct_path(target)

    return aio.secure_channel(
        target, composite_credentials, compression=compression, **kwargs
    )
310
311
class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall):
    """Fake unary-unary call that resolves to a preset response.

    Pass the intended response message at construction time; awaiting the
    object returns exactly that message. No real call is attached, so only
    awaiting is backed by this fake.
    """

    def __init__(self, response=object()):
        self.response = response
        # NOTE(review): relies on ``asyncio.get_event_loop()`` supplying a
        # loop; newer Python versions deprecate calling it with no running
        # loop, so construct inside async code where possible.
        pending = asyncio.get_event_loop().create_future()
        self._future = pending
        pending.set_result(self.response)

    def __await__(self):
        """Await the already-resolved future and return the preset response."""
        return (yield from self._future.__await__())
328
329
class FakeStreamUnaryCall(_WrappedStreamUnaryCall):
    """Fake stream-unary call that resolves to a preset response.

    Pass the intended response message at construction time; awaiting the
    object returns exactly that message, and ``wait_for_connection`` returns
    immediately.
    """

    def __init__(self, response=object()):
        self.response = response
        # NOTE(review): relies on ``asyncio.get_event_loop()`` supplying a
        # loop; newer Python versions deprecate calling it with no running
        # loop, so construct inside async code where possible.
        pending = asyncio.get_event_loop().create_future()
        self._future = pending
        pending.set_result(self.response)

    def __await__(self):
        """Await the already-resolved future and return the preset response."""
        return (yield from self._future.__await__())

    async def wait_for_connection(self):
        """No-op: there is no real connection to wait for."""