1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`grpc`."""
16from typing import Generic, Iterator, Optional, TypeVar
17
18import collections
19import functools
20import warnings
21
22import grpc
23
24from google.api_core import exceptions
25import google.auth
26import google.auth.credentials
27import google.auth.transport.grpc
28import google.auth.transport.requests
29import google.protobuf
30
PROTOBUF_VERSION = google.protobuf.__version__

# grpcio-gcp can only be used together with protobuf 3.x, so it is only probed
# for there. The flag stays False for newer protobuf releases and whenever
# grpc_gcp is not importable.
HAS_GRPC_GCP = False
if PROTOBUF_VERSION.startswith("3."):  # pragma: NO COVER
    try:
        import grpc_gcp
    except ImportError:
        pass
    else:
        warnings.warn(
            """Support for grpcio-gcp is deprecated. This feature will be
            removed from `google-api-core` after January 1, 2024. If you need to
            continue to use this feature, please pin to a specific version of
            `google-api-core`.""",
            DeprecationWarning,
        )
        HAS_GRPC_GCP = True


# gRPC multi-callable types whose invocation yields an iterator of responses.
_STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable)

# Type variable standing for the protobuf response message type of a gRPC call.
P = TypeVar("P")
57
58
59def _patch_callable_name(callable_):
60 """Fix-up gRPC callable attributes.
61
62 gRPC callable lack the ``__name__`` attribute which causes
63 :func:`functools.wraps` to error. This adds the attribute if needed.
64 """
65 if not hasattr(callable_, "__name__"):
66 callable_.__name__ = callable_.__class__.__name__
67
68
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary and Stream-Unary gRPC callables.

    Returns a wrapper that forwards all arguments to ``callable_`` and re-raises
    any :class:`grpc.RpcError` as the exception produced by
    :func:`google.api_core.exceptions.from_grpc_error`, chained to the
    original error.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc

    # Copy name/docstring/etc. from the gRPC callable onto the wrapper.
    return functools.wraps(callable_)(error_remapped_callable)
81
82
class _StreamingResponseIterator(Generic[P], grpc.Call):
    """Iterator over a gRPC response stream that maps errors to API exceptions.

    Wraps the object returned by a Unary-Stream or Stream-Stream callable.
    Iterating yields the wrapped stream's responses; a :class:`grpc.RpcError`
    raised while iterating is re-raised via
    :func:`google.api_core.exceptions.from_grpc_error`. The ``grpc.Call`` and
    ``grpc.RpcContext`` methods are forwarded to the wrapped object.

    Args:
        wrapped: The object returned by the gRPC callable.
        prefetch_first_result (bool): When true, the first response is pulled
            from ``wrapped`` during construction.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped
        if not prefetch_first_result:
            return

        # Instances are built inside a retry context and handed to the caller
        # afterwards. gRPC only raises once the stream is consumed, so the
        # first response is pulled now so that a failure can trigger a retry.
        try:
            self._stored_first_result = next(self._wrapped)
        except (TypeError, StopIteration):
            # TypeError: ``wrapped`` is not an iterator (a bare grpc.Call, for
            # instance), so nothing is stored.
            # StopIteration: an empty stream is left for the consumer to see,
            # outside of the retry.
            pass

    def __iter__(self) -> Iterator[P]:
        """Return ``self``; this iterator is its own iterable."""
        return self

    def __next__(self) -> P:
        """Return the next response from the stream.

        A response pre-fetched during construction is returned (and dropped)
        first; later calls pull directly from the wrapped stream.

        Returns:
            protobuf.Message: A single response from the stream.

        Raises:
            StopIteration: When the wrapped stream is exhausted.
            google.api_core.exceptions.GoogleAPICallError: The mapped form of a
                :class:`grpc.RpcError` raised by the stream.
        """
        try:
            if not hasattr(self, "_stored_first_result"):
                return next(self._wrapped)
            first = self._stored_first_result
            del self._stored_first_result
            return first
        except grpc.RpcError as exc:
            # Data may already have been handed out, so the stream cannot be
            # resumed here; only translate the error.
            raise exceptions.from_grpc_error(exc) from exc

    # --- grpc.Call & grpc.RpcContext interface: plain delegation ---

    def add_callback(self, callback):
        """Forward to the wrapped call's ``add_callback``."""
        return self._wrapped.add_callback(callback)

    def cancel(self):
        """Forward to the wrapped call's ``cancel``."""
        return self._wrapped.cancel()

    def code(self):
        """Forward to the wrapped call's ``code``."""
        return self._wrapped.code()

    def details(self):
        """Forward to the wrapped call's ``details``."""
        return self._wrapped.details()

    def initial_metadata(self):
        """Forward to the wrapped call's ``initial_metadata``."""
        return self._wrapped.initial_metadata()

    def is_active(self):
        """Forward to the wrapped call's ``is_active``."""
        return self._wrapped.is_active()

    def time_remaining(self):
        """Forward to the wrapped call's ``time_remaining``."""
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        """Forward to the wrapped call's ``trailing_metadata``."""
        return self._wrapped.trailing_metadata()
147
# Public type alias for the return type of streaming GAPIC calls: an iterator
# that yields response messages of type ``P`` and also forwards the
# ``grpc.Call`` interface to the underlying stream.
GrpcStream = _StreamingResponseIterator[P]
150
151
def _wrap_stream_errors(callable_):
    """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

    Callables that return iterators need extra work: errors can surface both
    when the call is made and while its responses are iterated. The returned
    wrapper maps errors at invocation time and wraps the result in a
    :class:`_StreamingResponseIterator`, which maps errors during iteration.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            stream = callable_(*args, **kwargs)
            # Pre-fetching the first result makes the PubSub client's streaming
            # pull hang when the stream is re-opened, so callables can opt out
            # through this private attribute.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                stream, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc

    # Copy name/docstring/etc. from the gRPC callable onto the wrapper.
    return functools.wraps(callable_)(error_remapped_callable)
177
178
def wrap_errors(callable_):
    """Wrap a gRPC callable so :class:`grpc.RpcError` becomes a friendly error.

    Errors raised by the gRPC callable are re-raised as the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclass. The
    original ``grpc.RpcError`` (which is usually also a ``grpc.Call``) remains
    available from the ``response`` property of the mapped exception, which is
    useful for extracting metadata from the original error.

    Streaming callables (Unary-Stream and Stream-Stream) additionally have
    errors raised while iterating their responses mapped the same way.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    is_streaming = isinstance(callable_, _STREAM_WRAP_CLASSES)
    wrapper_factory = _wrap_stream_errors if is_streaming else _wrap_unary_errors
    return wrapper_factory(callable_)
199
200
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Build the channel credentials used by secure channels.

    The credentials come from ``credentials_file``, from ``credentials``, or
    from the environment, in that order of preference. They are wrapped in a
    metadata plugin that adds the authorization header to each call, and
    combined with channel-level credentials.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, they are obtained from the environment using
            :func:`google.auth.default`.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. Mutually exclusive
            with ``credentials``.
        default_scopes (Sequence[str]): An optional list of default scopes
            needed for this service, forwarded to whichever credential loader
            is used.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service, forwarded to whichever credential loader is used.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        quota_project_id (str): An optional project to use for billing and quota.
        default_host (str): The default endpoint, e.g. "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.DuplicateCredentialArgs: If both a credentials object
            and credentials_file are passed.
    """
    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    scope_kwargs = {"scopes": scopes, "default_scopes": default_scopes}
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, **scope_kwargs
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, **scope_kwargs
        )
    else:
        credentials, _ = google.auth.default(**scope_kwargs)

    # Only credential types that support a quota project can be re-targeted.
    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    # Metadata plugin that inserts the authorization header into each call,
    # exposed to gRPC as call credentials.
    auth_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials,
        google.auth.transport.requests.Request(),
        default_host=default_host,
    )
    call_credentials = grpc.metadata_call_credentials(auth_plugin)

    if not ssl_credentials:
        # grpc.compute_engine_channel_credentials is used in order to support
        # Direct Path.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.compute_engine_channel_credentials
        # TODO(https://github.com/googleapis/python-api-core/issues/598):
        # Although `grpc.compute_engine_channel_credentials` returns channel
        # credentials outside of a Google Compute Engine environment (GCE), we
        # should determine if there is a way to reliably detect a GCE
        # environment so that it is not called outside of GCE.
        return grpc.compute_engine_channel_credentials(call_credentials)

    # Custom SSL credentials (needed for mTLS) can only be supplied through
    # grpc.composite_channel_credentials, which combines them with the
    # authorization call credentials.
    # See https://grpc.github.io/grpc/python/grpc.html#grpc.composite_channel_credentials
    return grpc.composite_channel_credentials(ssl_credentials, call_credentials)
289
290
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs,
):
    """Create a secure gRPC channel authenticated with Google credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, they are obtained from the environment using
            :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. These are only used when credentials are not specified and
            are passed to :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. Mutually exclusive
            with ``credentials``.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint, e.g. "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be
            attempted when the request is made (the target is rewritten to the
            ``google-c2p:///`` form). Direct Path is only available within a
            Google Compute Engine (GCE) environment and provides a proxyless
            connection which increases the available throughput, reduces
            latency, and increases reliability. Note:

            - This argument should only be set in a GCE environment and for
              Services that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will
              fail unless the back-end service happens to have configured
              fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is
              recommended that the client repeat the request with
              `attempt_direct_path` set to `False` as the Service may not
              support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True`
              will result in `ValueError` as this combination is not yet
              supported.

        kwargs: Additional keyword args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.
            Note: `grpc_gcp` is only supported in environments with
            protobuf < 4.0.0. When it is used, `compression` and
            `attempt_direct_path` are ignored with a DeprecationWarning.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.DuplicateCredentialArgs: If both a credentials object
            and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is
            set to `True`.
    """
    # Custom SSL credentials cannot yet be combined with Direct Path.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    channel_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    # grpcio-gcp is deprecated, but when it is available it still takes over
    # channel creation; it supports neither compression nor Direct Path.
    if HAS_GRPC_GCP:  # pragma: NO COVER
        compression_requested = (
            compression is not None and compression != grpc.Compression.NoCompression
        )
        if compression_requested:
            warnings.warn(
                "The `compression` argument is ignored for grpc_gcp.secure_channel creation.",
                DeprecationWarning,
            )
        if attempt_direct_path:
            warnings.warn(
                "The `attempt_direct_path` argument is ignored for grpc_gcp.secure_channel creation.",
                DeprecationWarning,
            )
        return grpc_gcp.secure_channel(target, channel_credentials, **kwargs)

    channel_target = (
        _modify_target_for_direct_path(target) if attempt_direct_path else target
    )
    return grpc.secure_channel(
        channel_target, channel_credentials, compression=compression, **kwargs
    )
389
390
391def _modify_target_for_direct_path(target: str) -> str:
392 """
393 Given a target, return a modified version which is compatible with Direct Path.
394
395 Args:
396 target (str): The target service address in the format 'hostname[:port]' or
397 'dns://hostname[:port]'.
398
399 Returns:
400 target (str): The target service address which is converted into a format compatible with Direct Path.
401 If the target contains `dns:///` or does not contain `:///`, the target will be converted in
402 a format compatible with Direct Path; otherwise the original target will be returned as the
403 original target may already denote Direct Path.
404 """
405
406 # A DNS prefix may be included with the target to indicate the endpoint is living in the Internet,
407 # outside of Google Cloud Platform.
408 dns_prefix = "dns:///"
409 # Remove "dns:///" if `attempt_direct_path` is set to True as
410 # the Direct Path prefix `google-c2p:///` will be used instead.
411 target = target.replace(dns_prefix, "")
412
413 direct_path_separator = ":///"
414 if direct_path_separator not in target:
415 target_without_port = target.split(":")[0]
416 # Modify the target to use Direct Path by adding the `google-c2p:///` prefix
417 target = f"google-c2p{direct_path_separator}{target_without_port}"
418 return target
419
420
421_MethodCall = collections.namedtuple(
422 "_MethodCall", ("request", "timeout", "metadata", "credentials", "compression")
423)
424
425_ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))
426
427
428class _CallableStub(object):
429 """Stub for the grpc.*MultiCallable interfaces."""
430
431 def __init__(self, method, channel):
432 self._method = method
433 self._channel = channel
434 self.response = None
435 """Union[protobuf.Message, Callable[protobuf.Message], exception]:
436 The response to give when invoking this callable. If this is a
437 callable, it will be invoked with the request protobuf. If it's an
438 exception, the exception will be raised when this is invoked.
439 """
440 self.responses = None
441 """Iterator[
442 Union[protobuf.Message, Callable[protobuf.Message], exception]]:
443 An iterator of responses. If specified, self.response will be populated
444 on each invocation by calling ``next(self.responses)``."""
445 self.requests = []
446 """List[protobuf.Message]: All requests sent to this callable."""
447 self.calls = []
448 """List[Tuple]: All invocations of this callable. Each tuple is the
449 request, timeout, metadata, compression, and credentials."""
450
451 def __call__(
452 self, request, timeout=None, metadata=None, credentials=None, compression=None
453 ):
454 self._channel.requests.append(_ChannelRequest(self._method, request))
455 self.calls.append(
456 _MethodCall(request, timeout, metadata, credentials, compression)
457 )
458 self.requests.append(request)
459
460 response = self.response
461 if self.responses is not None:
462 if response is None:
463 response = next(self.responses)
464 else:
465 raise ValueError(
466 "{method}.response and {method}.responses are mutually "
467 "exclusive.".format(method=self._method)
468 )
469
470 if callable(response):
471 return response(request)
472
473 if isinstance(response, Exception):
474 raise response
475
476 if response is not None:
477 return response
478
479 raise ValueError('Method stub for "{}" has no response.'.format(self._method))
480
481
482def _simplify_method_name(method):
483 """Simplifies a gRPC method name.
484
485 When gRPC invokes the channel to create a callable, it gives a full
486 method name like "/google.pubsub.v1.Publisher/CreateTopic". This
487 returns just the name of the method, in this case "CreateTopic".
488
489 Args:
490 method (str): The name of the method.
491
492 Returns:
493 str: The simplified name of the method.
494 """
495 return method.rsplit("/", 1).pop()
496
497
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here are some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        """Create an empty channel stub.

        Args:
            responses: Unused; accepted only for backward compatibility.
                (Defaults to None rather than a shared mutable list.)
        """
        # Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
        # on this channel in order. The tuple is of method name, request
        # message.
        self.requests = []
        # Maps simplified method name -> _CallableStub.
        self._method_stubs = {}

    def _stub_for_method(self, method):
        """Create (replacing any existing one) and return the stub for ``method``."""
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        """Expose created method stubs as attributes, e.g. ``channel_stub.GetFoo``.

        Raises:
            AttributeError: If no stub has been created for ``key``.
        """
        # Read through ``__dict__``: on an instance whose ``__init__`` has not
        # run (e.g. one created by ``copy`` or ``pickle``), ``self._method_stubs``
        # would re-enter ``__getattr__`` and recurse until RecursionError.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(key) from None

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation (no-op)."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation (no-op)."""
        pass

    def close(self):
        """grpc.Channel.close implementation (no-op)."""
        pass