1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`grpc`."""
16import collections
17import functools
18from typing import Generic, Iterator, Optional, TypeVar
19import warnings
20
21import google.auth
22import google.auth.credentials
23import google.auth.transport.grpc
24import google.auth.transport.requests
25import google.protobuf
26import grpc
27
28from google.api_core import exceptions, general_helpers
29
# Version string of the installed protobuf runtime (``google.protobuf.__version__``).
PROTOBUF_VERSION = google.protobuf.__version__

# The grpcio-gcp package only has support for protobuf < 4, so the optional
# ``grpc_gcp`` import is only attempted when the protobuf version string starts
# with "3.". ``HAS_GRPC_GCP`` records whether that import succeeded; it is
# False whenever protobuf is not 3.x or ``grpc_gcp`` is not installed.
if PROTOBUF_VERSION[0:2] == "3.":  # pragma: NO COVER
    try:
        import grpc_gcp

        # The import succeeded, so warn that this integration is deprecated.
        warnings.warn(
            """Support for grpcio-gcp is deprecated. This feature will be
            removed from `google-api-core` after January 1, 2024. If you need to
            continue to use this feature, please pin to a specific version of
            `google-api-core`.""",
            DeprecationWarning,
        )
        HAS_GRPC_GCP = True
    except ImportError:
        # grpcio-gcp is an optional dependency; its absence is not an error.
        HAS_GRPC_GCP = False
else:
    HAS_GRPC_GCP = False
49
50
# The gRPC multi-callable interfaces whose invocation returns a response
# iterator (unary-stream and stream-stream). Kept as a tuple so it can be
# passed directly to ``isinstance``.
_STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable)

# Type variable denoting the proto response type for gRPC calls.
P = TypeVar("P")
56
57
58def _patch_callable_name(callable_):
59 """Fix-up gRPC callable attributes.
60
61 gRPC callable lack the ``__name__`` attribute which causes
62 :func:`functools.wraps` to error. This adds the attribute if needed.
63 """
64 if not hasattr(callable_, "__name__"):
65 callable_.__name__ = callable_.__class__.__name__
66
67
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary and Stream-Unary gRPC callables.

    Args:
        callable_ (Callable): A gRPC callable returning a single response.

    Returns:
        Callable: A wrapper with the same metadata as ``callable_`` that
        re-raises any :class:`grpc.RpcError` as the matching
        :mod:`google.api_core.exceptions` error, chained to the original.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    # Copy ``__name__``/``__doc__``/etc. from the gRPC callable onto the wrapper.
    return functools.wraps(callable_)(error_remapped_callable)
80
81
class _StreamingResponseIterator(Generic[P], grpc.Call):
    """Iterate a streaming gRPC response, re-mapping gRPC errors.

    Every :class:`grpc.Call` / :class:`grpc.RpcContext` method is forwarded to
    the wrapped call object. When prefetching is enabled, the first response
    is read eagerly in ``__init__`` and replayed by the first ``__next__``.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped

        if not prefetch_first_result:
            return

        # This iterator is used in a retry context and handed back to the
        # caller after construction. gRPC only raises once the stream is
        # consumed, so the first item is fetched here to surface a failure
        # (and trigger a retry) as early as possible.
        try:
            self._stored_first_result = next(self._wrapped)
        except (TypeError, StopIteration):
            # TypeError: the wrapped object may not be an iterator (for
            # instance a bare grpc.Call); nothing is stored in that case.
            # StopIteration: an empty stream is handled outside of retry, by
            # the first ``__next__`` call.
            pass

    def __iter__(self) -> Iterator[P]:
        """Return this object; it is its own iterator."""
        return self

    def __next__(self) -> P:
        """Return the next response from the stream.

        Returns:
            protobuf.Message: A single response from the stream.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: Mapped from any
                :class:`grpc.RpcError` raised while reading the stream.
        """
        try:
            prefetched = self._stored_first_result
        except AttributeError:
            pass
        else:
            # Replay the prefetched item exactly once.
            del self._stored_first_result
            return prefetched

        try:
            return next(self._wrapped)
        except grpc.RpcError as rpc_error:
            # Once the stream has produced data it cannot be retried here.
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    # grpc.Call & grpc.RpcContext interface: pure delegation.

    def add_callback(self, callback):
        return self._wrapped.add_callback(callback)

    def cancel(self):
        return self._wrapped.cancel()

    def code(self):
        return self._wrapped.code()

    def details(self):
        return self._wrapped.details()

    def initial_metadata(self):
        return self._wrapped.initial_metadata()

    def is_active(self):
        return self._wrapped.is_active()

    def time_remaining(self):
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        return self._wrapped.trailing_metadata()
145
146
# Public type alias denoting the return type of streaming GAPIC calls: the
# error-remapping stream iterator, parameterized by the response type ``P``.
GrpcStream = _StreamingResponseIterator[P]
149
150
def _wrap_stream_errors(callable_):
    """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

    Callables that return iterators need extra handling: errors can be raised
    both by the initial invocation and later while iterating the result. The
    returned wrapper re-maps the former directly and wraps the result in a
    :class:`_StreamingResponseIterator`, which re-maps the latter.

    Args:
        callable_ (Callable): A streaming gRPC callable.

    Returns:
        Callable: The error-remapping wrapper.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            stream = callable_(*args, **kwargs)
            # Auto-fetching the first result makes the PubSub client's
            # streaming pull hang when re-opening the stream, so a hidden
            # flag on the callable can disable pre-fetching. It is read on
            # every call (not once at wrap time) so it may be set later.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            # Construction may prefetch, so it must stay inside the ``try``.
            return _StreamingResponseIterator(
                stream, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return error_remapped_callable
176
177
def wrap_errors(callable_):
    """Wrap a gRPC callable and map :class:`grpc.RpcError` to friendly error
    classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses.
    The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Unary-stream and stream-stream callables additionally have errors
    re-mapped while their response stream is iterated.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    returns_stream = isinstance(callable_, _STREAM_WRAP_CLASSES)
    wrapper_factory = _wrap_stream_errors if returns_stream else _wrap_unary_errors
    return wrapper_factory(callable_)
198
199
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite credentials for secure channels.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        credentials_file (str): Deprecated. A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials. This argument will be
            removed in the next major version of `google-api-core`.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        default_scopes (Sequence[str]): An optional list of default scopes
            needed for this service. Passed, together with ``scopes``, to
            whichever of :func:`google.auth.load_credentials_from_file`,
            :func:`google.auth.credentials.with_scopes_if_required`, or
            :func:`google.auth.default` resolves the credentials.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Passed in the same way as ``default_scopes``.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        quota_project_id (str): An optional project to use for billing and
            quota. Only applied when the resolved credentials are a
            :class:`google.auth.credentials.CredentialsWithQuotaProject`.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object:
        :func:`grpc.composite_channel_credentials` when ``ssl_credentials``
        is given, otherwise :func:`grpc.compute_engine_channel_credentials`.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    if credentials_file is not None:
        warnings.warn(general_helpers._CREDENTIALS_FILE_WARNING, DeprecationWarning)

    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    # Resolve the credentials from exactly one source: file, explicit
    # object, or the environment.
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, scopes=scopes, default_scopes=default_scopes
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, scopes=scopes, default_scopes=default_scopes
        )
    else:
        credentials, _ = google.auth.default(
            scopes=scopes, default_scopes=default_scopes
        )

    # Not every credentials type supports a quota project; skip the others.
    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    # HTTP transport used by the auth plugin to refresh tokens.
    request = google.auth.transport.requests.Request()

    # Create the metadata plugin for inserting the authorization header.
    metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials,
        request,
        default_host=default_host,
    )

    # Create a set of grpc.CallCredentials using the metadata plugin.
    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

    # if `ssl_credentials` is set, use `grpc.composite_channel_credentials` instead of
    # `grpc.compute_engine_channel_credentials` as the former supports passing
    # `ssl_credentials` via `channel_credentials` which is needed for mTLS.
    if ssl_credentials:
        # Combine the ssl credentials and the authorization credentials.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.composite_channel_credentials
        return grpc.composite_channel_credentials(
            ssl_credentials, google_auth_credentials
        )
    else:
        # Use grpc.compute_engine_channel_credentials in order to support Direct Path.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.compute_engine_channel_credentials
        # TODO(https://github.com/googleapis/python-api-core/issues/598):
        # Although `grpc.compute_engine_channel_credentials` returns channel credentials
        # outside of a Google Compute Engine environment (GCE), we should determine if
        # there is a way to reliably detect a GCE environment so that
        # `grpc.compute_engine_channel_credentials` is not called outside of GCE.
        return grpc.compute_engine_channel_credentials(google_auth_credentials)
304
305
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs,
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Applied to the credentials however they are obtained
            (explicit object, ``credentials_file``, or
            :func:`google.auth.default`).
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): Deprecated. A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials. Passing it emits a
            ``DeprecationWarning``; it will be removed in the next major
            version of `google-api-core`.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted
            when the request is made. Direct Path is only available within a Google
            Compute Engine (GCE) environment and provides a proxyless connection
            which increases the available throughput, reduces latency, and increases
            reliability. Note:

            - This argument should only be set in a GCE environment and for Services
              that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will fail
              unless the back-end service happens to have configured fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is recommended
              that the client repeat the request with `attempt_direct_path` set to
              `False` as the Service may not support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
              result in `ValueError` as this combination is not yet supported.

        kwargs: Additional key-word args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.
            Note: `grpc_gcp` is only supported in environments with protobuf < 4.0.0.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`.
    """

    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`,
    # raise ValueError as this is not yet supported.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    composite_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    # Note that grpcio-gcp is deprecated
    if HAS_GRPC_GCP:  # pragma: NO COVER
        # grpc_gcp.secure_channel is not given `compression` or the Direct
        # Path target, so warn when the caller asked for either.
        if compression is not None and compression != grpc.Compression.NoCompression:
            warnings.warn(
                "The `compression` argument is ignored for grpc_gcp.secure_channel creation.",
                DeprecationWarning,
            )
        if attempt_direct_path:
            warnings.warn(
                """The `attempt_direct_path` argument is ignored for grpc_gcp.secure_channel creation.""",
                DeprecationWarning,
            )
        return grpc_gcp.secure_channel(target, composite_credentials, **kwargs)

    if attempt_direct_path:
        target = _modify_target_for_direct_path(target)

    return grpc.secure_channel(
        target, composite_credentials, compression=compression, **kwargs
    )
416
417
418def _modify_target_for_direct_path(target: str) -> str:
419 """
420 Given a target, return a modified version which is compatible with Direct Path.
421
422 Args:
423 target (str): The target service address in the format 'hostname[:port]' or
424 'dns://hostname[:port]'.
425
426 Returns:
427 target (str): The target service address which is converted into a format compatible with Direct Path.
428 If the target contains `dns:///` or does not contain `:///`, the target will be converted in
429 a format compatible with Direct Path; otherwise the original target will be returned as the
430 original target may already denote Direct Path.
431 """
432
433 # A DNS prefix may be included with the target to indicate the endpoint is living in the Internet,
434 # outside of Google Cloud Platform.
435 dns_prefix = "dns:///"
436 # Remove "dns:///" if `attempt_direct_path` is set to True as
437 # the Direct Path prefix `google-c2p:///` will be used instead.
438 target = target.replace(dns_prefix, "")
439
440 direct_path_separator = ":///"
441 if direct_path_separator not in target:
442 target_without_port = target.split(":")[0]
443 # Modify the target to use Direct Path by adding the `google-c2p:///` prefix
444 target = f"google-c2p{direct_path_separator}{target_without_port}"
445 return target
446
447
448_MethodCall = collections.namedtuple(
449 "_MethodCall", ("request", "timeout", "metadata", "credentials", "compression")
450)
451
452_ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))
453
454
class _CallableStub(object):
    """Stub for the grpc.*MultiCallable interfaces.

    Every invocation is recorded (on this stub's ``requests`` and ``calls``,
    and on the owning channel's ``requests``) before being answered from
    ``response`` or ``responses``.
    """

    def __init__(self, method, channel):
        self._method = method
        self._channel = channel
        self.response = None
        """Union[protobuf.Message, Callable[protobuf.Message], exception]:
        The response to give when invoking this callable. If this is a
        callable, it will be invoked with the request protobuf. If it's an
        exception, the exception will be raised when this is invoked.
        Mutually exclusive with ``responses``.
        """
        self.responses = None
        """Iterator[
            Union[protobuf.Message, Callable[protobuf.Message], exception]]:
        An iterator of responses. If specified, each invocation takes its
        response from ``next(self.responses)``; ``self.response`` itself is
        not modified. Invoking the stub after the iterator is exhausted
        raises ``StopIteration``."""
        self.requests = []
        """List[protobuf.Message]: All requests sent to this callable."""
        self.calls = []
        """List[_MethodCall]: All invocations of this callable. Each entry is a
        ``(request, timeout, metadata, credentials, compression)`` tuple."""

    def __call__(
        self, request, timeout=None, metadata=None, credentials=None, compression=None
    ):
        # Record the call first, so it is tracked even if answering it fails.
        self._channel.requests.append(_ChannelRequest(self._method, request))
        self.calls.append(
            _MethodCall(request, timeout, metadata, credentials, compression)
        )
        self.requests.append(request)

        response = self.response
        if self.responses is not None:
            if response is None:
                response = next(self.responses)
            else:
                raise ValueError(
                    "{method}.response and {method}.responses are mutually "
                    "exclusive.".format(method=self._method)
                )

        # A callable computes the response from the request.
        if callable(response):
            return response(request)

        # An exception instance simulates a failing RPC.
        if isinstance(response, Exception):
            raise response

        if response is not None:
            return response

        raise ValueError('Method stub for "{}" has no response.'.format(self._method))
507
508
509def _simplify_method_name(method):
510 """Simplifies a gRPC method name.
511
512 When gRPC invokes the channel to create a callable, it gives a full
513 method name like "/google.pubsub.v1.Publisher/CreateTopic". This
514 returns just the name of the method, in this case "CreateTopic".
515
516 Args:
517 method (str): The name of the method.
518
519 Returns:
520 str: The simplified name of the method.
521 """
522 return method.rsplit("/", 1).pop()
523
524
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here are some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        # ``responses`` is accepted for backwards compatibility but is not
        # used; configure responses on the per-method stubs instead. The
        # ``None`` default avoids a shared mutable default argument.
        self.requests = []
        """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
        on this channel in order. The tuple is of method name, request
        message."""
        self._method_stubs = {}

    def _stub_for_method(self, method):
        # Stubs are keyed by the short method name so they can be reached as
        # attributes (e.g. ``channel_stub.GetFoo``). A new stub replaces any
        # previous one registered for the same name.
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. The stub table is
        # read through ``__dict__`` so that an instance whose ``__init__`` has
        # not run (e.g. during ``copy.copy`` or unpickling) raises
        # AttributeError instead of recursing into
        # ``__getattr__("_method_stubs")`` until RecursionError.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(key) from None

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation (no-op)."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation (no-op)."""
        pass

    def close(self):
        """grpc.Channel.close implementation (no-op)."""
        pass