1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`grpc`."""
16from typing import Generic, Iterator, Optional, TypeVar
17
18import collections
19import functools
20import warnings
21
22import grpc
23
24from google.api_core import exceptions
25import google.auth
26import google.auth.credentials
27import google.auth.transport.grpc
28import google.auth.transport.requests
29import google.protobuf
30
PROTOBUF_VERSION = google.protobuf.__version__

# grpcio-gcp only works with protobuf < 4, so it is only considered when the
# installed protobuf is a 3.x release. Its use is deprecated either way.
HAS_GRPC_GCP = False
if PROTOBUF_VERSION.startswith("3."):  # pragma: NO COVER
    try:
        import grpc_gcp
    except ImportError:
        pass
    else:
        warnings.warn(
            """Support for grpcio-gcp is deprecated. This feature will be
            removed from `google-api-core` after January 1, 2024. If you need to
            continue to use this feature, please pin to a specific version of
            `google-api-core`.""",
            DeprecationWarning,
        )
        HAS_GRPC_GCP = True


# gRPC multi-callable types whose invocation yields an iterator of responses;
# these need stream-aware error wrapping (see ``wrap_errors``).
_STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable)

# Type variable standing for the protobuf response type of a gRPC call.
P = TypeVar("P")
57
58
59def _patch_callable_name(callable_):
60 """Fix-up gRPC callable attributes.
61
62 gRPC callable lack the ``__name__`` attribute which causes
63 :func:`functools.wraps` to error. This adds the attribute if needed.
64 """
65 if not hasattr(callable_, "__name__"):
66 callable_.__name__ = callable_.__class__.__name__
67
68
def _wrap_unary_errors(callable_):
    """Wrap a Unary-Unary or Stream-Unary gRPC callable to remap its errors.

    Any :class:`grpc.RpcError` raised by the call is re-raised as the
    matching :mod:`google.api_core.exceptions` type, chained to the original.

    Args:
        callable_ (Callable): The gRPC callable to wrap.

    Returns:
        Callable: A wrapper with the same call signature.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return error_remapped_callable
81
82
class _StreamingResponseIterator(Generic[P], grpc.Call):
    """Iterator over a streaming gRPC response that remaps gRPC errors.

    Wraps the object returned by a Unary-Stream or Stream-Stream callable.
    Iteration re-raises :class:`grpc.RpcError` as the corresponding
    :mod:`google.api_core.exceptions` type, and the ``grpc.Call`` /
    ``grpc.RpcContext`` methods are delegated to the wrapped object.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        """
        Args:
            wrapped: The value returned by the underlying streaming callable.
            prefetch_first_result (bool): If truthy, pull the first response
                immediately so that errors surface at construction time.
        """
        self._wrapped = wrapped

        # The iterator is built inside a retry context and handed out after
        # construction. gRPC only raises once the stream is consumed, so the
        # first response is pulled eagerly to let failures trigger a retry.
        try:
            if prefetch_first_result:
                self._stored_first_result = next(self._wrapped)
        except (TypeError, StopIteration):
            # TypeError: ``wrapped`` is not an iterator (e.g. a bare
            # grpc.Call), so there is nothing to prefetch.
            # StopIteration: an empty stream is reported later, outside retry.
            pass

    def __iter__(self) -> Iterator[P]:
        """Return ``self``; this iterator is its own iterable."""
        return self

    def __next__(self) -> P:
        """Return the next response from the stream.

        A prefetched first response, if any, is handed out (once) before
        reading further from the wrapped iterator.

        Returns:
            protobuf.Message: A single response from the stream.
        """
        try:
            if not hasattr(self, "_stored_first_result"):
                return next(self._wrapped)
            first_result = self._stored_first_result
            del self._stored_first_result
            return first_result
        except grpc.RpcError as rpc_error:
            # Data may already have been returned, so recovery is not
            # attempted here; the error is only remapped.
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    # grpc.Call & grpc.RpcContext interface: delegate to the wrapped call.

    def add_callback(self, callback):
        return self._wrapped.add_callback(callback)

    def cancel(self):
        return self._wrapped.cancel()

    def code(self):
        return self._wrapped.code()

    def details(self):
        return self._wrapped.details()

    def initial_metadata(self):
        return self._wrapped.initial_metadata()

    def is_active(self):
        return self._wrapped.is_active()

    def time_remaining(self):
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        return self._wrapped.trailing_metadata()
146
147
# Public type alias for the return type of streaming gapic calls: the
# ``_StreamingResponseIterator`` produced by ``_wrap_stream_errors``, i.e. an
# iterator of responses of type ``P`` that also exposes the ``grpc.Call``
# interface.
GrpcStream = _StreamingResponseIterator[P]
150
151
def _wrap_stream_errors(callable_):
    """Wrap a Unary-Stream or Stream-Stream gRPC callable to remap its errors.

    Streaming callables need extra handling: errors may be raised by the
    invocation itself or later while iterating. The wrapper remaps errors
    from the invocation (including the optional first-result prefetch) and
    returns a :class:`_StreamingResponseIterator`, which remaps errors raised
    during iteration.

    Args:
        callable_ (Callable): The streaming gRPC callable to wrap.

    Returns:
        Callable: A wrapper returning a ``_StreamingResponseIterator``.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            response_stream = callable_(*args, **kwargs)
            # Prefetching the first result makes the PubSub streaming pull
            # hang when re-opening the stream, so a hidden flag on the
            # callable can disable it. It is read on every call.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                response_stream, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return error_remapped_callable
177
178
def wrap_errors(callable_):
    """Wrap a gRPC callable so :class:`grpc.RpcError` becomes a friendly error.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses.
    The original ``grpc.RpcError`` (which is usually also a ``grpc.Call``) is
    available from the ``response`` property on the mapped exception, which
    is useful for extracting metadata from the original error.

    Streaming multi-callables get stream-aware wrapping; every other callable
    is treated as unary.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    wrapper = (
        _wrap_stream_errors
        if isinstance(callable_, _STREAM_WRAP_CLASSES)
        else _wrap_unary_errors
    )
    return wrapper(callable_)
199
200
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite channel credentials for secure channels.

    Google auth credentials are taken from ``credentials_file`` if given,
    otherwise from ``credentials`` (re-scoped if required), otherwise from
    the environment via :func:`google.auth.default`. They are then attached
    to the channel through a gRPC auth metadata plugin.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        default_scopes (Sequence[str]): An optional list of default scopes
            for this service. Passed along with ``scopes`` to whichever of
            :func:`google.auth.load_credentials_from_file`,
            :func:`google.auth.credentials.with_scopes_if_required`, or
            :func:`google.auth.default` produces the credentials.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Passed along in the same way as ``default_scopes``.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        quota_project_id (str): An optional project to use for billing and
            quota. Applied only when the credentials support a quota project.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    scope_kwargs = dict(scopes=scopes, default_scopes=default_scopes)
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, **scope_kwargs
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, **scope_kwargs
        )
    else:
        credentials, _ = google.auth.default(**scope_kwargs)

    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    # Metadata plugin that inserts the authorization header on each call.
    auth_request = google.auth.transport.requests.Request()
    auth_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials,
        auth_request,
        default_host=default_host,
    )
    call_credentials = grpc.metadata_call_credentials(auth_plugin)

    if not ssl_credentials:
        # Use grpc.compute_engine_channel_credentials in order to support Direct Path.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.compute_engine_channel_credentials
        # TODO(https://github.com/googleapis/python-api-core/issues/598):
        # Although `grpc.compute_engine_channel_credentials` returns channel credentials
        # outside of a Google Compute Engine environment (GCE), we should determine if
        # there is a way to reliably detect a GCE environment so that
        # `grpc.compute_engine_channel_credentials` is not called outside of GCE.
        return grpc.compute_engine_channel_credentials(call_credentials)

    # Custom SSL credentials (e.g. for mTLS) can only be passed through
    # grpc.composite_channel_credentials.
    # See https://grpc.github.io/grpc/python/grpc.html#grpc.composite_channel_credentials
    return grpc.composite_channel_credentials(ssl_credentials, call_credentials)
301
302
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs,
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted
            when the request is made. Direct Path is only available within a Google
            Compute Engine (GCE) environment and provides a proxyless connection
            which increases the available throughput, reduces latency, and increases
            reliability. Note:

            - This argument should only be set in a GCE environment and for Services
              that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will fail
              unless the back-end service happens to have configured fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is recommended
              that the client repeat the request with `attempt_direct_path` set to
              `False` as the Service may not support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
              result in `ValueError` as this combination is not yet supported.

        kwargs: Additional keyword args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.
            Note: `grpc_gcp` is only supported in environments with protobuf < 4.0.0.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`.
    """
    # Direct Path cannot yet be combined with custom SSL credentials.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    channel_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    # Deprecated grpcio-gcp path: it supports neither compression nor Direct
    # Path, so both are ignored (with a warning) when requested.
    if HAS_GRPC_GCP:  # pragma: NO COVER
        uses_compression = (
            compression is not None and compression != grpc.Compression.NoCompression
        )
        if uses_compression:
            warnings.warn(
                "The `compression` argument is ignored for grpc_gcp.secure_channel creation.",
                DeprecationWarning,
            )
        if attempt_direct_path:
            warnings.warn(
                """The `attempt_direct_path` argument is ignored for grpc_gcp.secure_channel creation.""",
                DeprecationWarning,
            )
        return grpc_gcp.secure_channel(target, channel_credentials, **kwargs)

    channel_target = (
        _modify_target_for_direct_path(target) if attempt_direct_path else target
    )
    return grpc.secure_channel(
        channel_target, channel_credentials, compression=compression, **kwargs
    )
413
414
415def _modify_target_for_direct_path(target: str) -> str:
416 """
417 Given a target, return a modified version which is compatible with Direct Path.
418
419 Args:
420 target (str): The target service address in the format 'hostname[:port]' or
421 'dns://hostname[:port]'.
422
423 Returns:
424 target (str): The target service address which is converted into a format compatible with Direct Path.
425 If the target contains `dns:///` or does not contain `:///`, the target will be converted in
426 a format compatible with Direct Path; otherwise the original target will be returned as the
427 original target may already denote Direct Path.
428 """
429
430 # A DNS prefix may be included with the target to indicate the endpoint is living in the Internet,
431 # outside of Google Cloud Platform.
432 dns_prefix = "dns:///"
433 # Remove "dns:///" if `attempt_direct_path` is set to True as
434 # the Direct Path prefix `google-c2p:///` will be used instead.
435 target = target.replace(dns_prefix, "")
436
437 direct_path_separator = ":///"
438 if direct_path_separator not in target:
439 target_without_port = target.split(":")[0]
440 # Modify the target to use Direct Path by adding the `google-c2p:///` prefix
441 target = f"google-c2p{direct_path_separator}{target_without_port}"
442 return target
443
444
445_MethodCall = collections.namedtuple(
446 "_MethodCall", ("request", "timeout", "metadata", "credentials", "compression")
447)
448
449_ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))
450
451
class _CallableStub(object):
    """Stub for the grpc.*MultiCallable interfaces.

    Records every invocation (on itself and on the owning channel) and
    answers with a configured ``response`` or the next item of ``responses``.
    """

    def __init__(self, method, channel):
        self._method = method
        self._channel = channel
        #: Union[protobuf.Message, Callable[protobuf.Message], exception]:
        #: The response to give when invoking this callable. If this is a
        #: callable, it is invoked with the request protobuf. If it's an
        #: exception, the exception is raised when this is invoked.
        self.response = None
        #: Iterator[Union[protobuf.Message, Callable[protobuf.Message], exception]]:
        #: An iterator of responses. If set, each invocation takes its response
        #: from ``next(self.responses)``; must not be combined with ``response``.
        self.responses = None
        #: List[protobuf.Message]: All requests sent to this callable.
        self.requests = []
        #: List[_MethodCall]: All invocations of this callable. Each tuple is
        #: the request, timeout, metadata, credentials, and compression.
        self.calls = []

    def __call__(
        self, request, timeout=None, metadata=None, credentials=None, compression=None
    ):
        # Record the call before resolving the response, so even failing
        # invocations are tracked.
        self._channel.requests.append(_ChannelRequest(self._method, request))
        self.calls.append(
            _MethodCall(request, timeout, metadata, credentials, compression)
        )
        self.requests.append(request)

        reply = self.response
        if self.responses is not None:
            if reply is not None:
                raise ValueError(
                    "{method}.response and {method}.responses are mutually "
                    "exclusive.".format(method=self._method)
                )
            reply = next(self.responses)

        if callable(reply):
            return reply(request)
        if isinstance(reply, Exception):
            raise reply
        if reply is None:
            raise ValueError('Method stub for "{}" has no response.'.format(self._method))
        return reply
504
505
506def _simplify_method_name(method):
507 """Simplifies a gRPC method name.
508
509 When gRPC invokes the channel to create a callable, it gives a full
510 method name like "/google.pubsub.v1.Publisher/CreateTopic". This
511 returns just the name of the method, in this case "CreateTopic".
512
513 Args:
514 method (str): The name of the method.
515
516 Returns:
517 str: The simplified name of the method.
518 """
519 return method.rsplit("/", 1).pop()
520
521
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here's some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        """
        Args:
            responses: Unused; accepted only for backward compatibility.
                (Previously defaulted to a shared mutable ``[]``.)
        """
        #: Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
        #: on this channel in order. The tuple is of method name, request
        #: message.
        self.requests = []
        # Simplified method name -> _CallableStub.
        self._method_stubs = {}

    def _stub_for_method(self, method):
        """Create (replacing any existing one) and return the stub for ``method``."""
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        """Expose method stubs as attributes, e.g. ``channel_stub.GetFoo``."""
        # Read through ``__dict__`` instead of ``self._method_stubs``: when
        # ``__init__`` has not run (``copy.copy``, unpickling), accessing the
        # missing attribute would re-enter ``__getattr__`` and recurse forever.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(key) from None

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation (no-op)."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation (no-op)."""
        pass

    def close(self):
        """grpc.Channel.close implementation (no-op)."""
        pass