1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`grpc`."""
16import collections
17import functools
18from typing import Generic, Iterator, Optional, TypeVar
19import warnings
20
21import google.auth
22import google.auth.credentials
23import google.auth.transport.grpc
24import google.auth.transport.requests
25import google.protobuf
26import grpc
27
28from google.api_core import exceptions, general_helpers
29
30
# gRPC multi-callable types whose invocation returns a response stream
# (an iterator of messages) rather than a single response message.
_STREAM_WRAP_CLASSES = (
    grpc.UnaryStreamMultiCallable,
    grpc.StreamStreamMultiCallable,
)

# Type variable standing for the proto response type of a gRPC call.
P = TypeVar("P")
36
37
38def _patch_callable_name(callable_):
39 """Fix-up gRPC callable attributes.
40
41 gRPC callable lack the ``__name__`` attribute which causes
42 :func:`functools.wraps` to error. This adds the attribute if needed.
43 """
44 if not hasattr(callable_, "__name__"):
45 callable_.__name__ = callable_.__class__.__name__
46
47
def _wrap_unary_errors(callable_):
    """Wrap a Unary-Unary or Stream-Unary gRPC callable to re-map errors.

    Any :class:`grpc.RpcError` raised by the call is re-raised as the
    matching :mod:`google.api_core.exceptions` error, chained to the original.

    Args:
        callable_ (Callable): A non-streaming gRPC callable.

    Returns:
        Callable: The wrapped callable.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return functools.wraps(callable_)(error_remapped_callable)
60
61
class _StreamingResponseIterator(Generic[P], grpc.Call):
    """Iterator over a gRPC response stream that re-maps gRPC errors.

    Items are pulled from the wrapped stream. Any :class:`grpc.RpcError`
    raised while iterating is converted with
    :func:`exceptions.from_grpc_error`. The :class:`grpc.Call` /
    :class:`grpc.RpcContext` methods are forwarded to the wrapped object.

    Args:
        wrapped: The object returned by a streaming gRPC callable.
        prefetch_first_result (bool): When true, the constructor eagerly reads
            the first response. The first :meth:`__next__` call then returns it.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped

        if not prefetch_first_result:
            return

        # This object is built inside a retry context and handed out
        # afterwards. gRPC does not raise until the stream is consumed, so the
        # first response is pulled here to let an early failure trigger a retry.
        try:
            self._stored_first_result = next(self._wrapped)
        except (TypeError, StopIteration):
            # TypeError: the wrapped object may not be an iterator (a grpc.Call
            # for instance); in that case nothing is stored.
            # StopIteration: ignored here; it is handled outside of the retry.
            pass

    def __iter__(self) -> Iterator[P]:
        """Return this object, which is its own iterator."""
        return self

    def __next__(self) -> P:
        """Get the next response from the stream.

        Returns:
            protobuf.Message: A single response from the stream.
        """
        try:
            if not hasattr(self, "_stored_first_result"):
                return next(self._wrapped)
            # Hand out the prefetched response exactly once.
            first = self._stored_first_result
            del self._stored_first_result
            return first
        except grpc.RpcError as rpc_error:
            # If the stream has already returned data, we cannot recover here.
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    # grpc.Call & grpc.RpcContext interface: forwarded to the wrapped call.

    def add_callback(self, callback):
        return self._wrapped.add_callback(callback)

    def cancel(self):
        return self._wrapped.cancel()

    def code(self):
        return self._wrapped.code()

    def details(self):
        return self._wrapped.details()

    def initial_metadata(self):
        return self._wrapped.initial_metadata()

    def is_active(self):
        return self._wrapped.is_active()

    def time_remaining(self):
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        return self._wrapped.trailing_metadata()
125
126
# Public type alias for the return type of streaming gapic calls: the
# _StreamingResponseIterator produced by the stream wrapper. It iterates over
# responses of type ``P`` and also forwards the grpc.Call interface of the
# underlying call.
GrpcStream = _StreamingResponseIterator[P]
129
130
def _wrap_stream_errors(callable_):
    """Wrap a Unary-Stream or Stream-Stream gRPC callable to re-map errors.

    Streaming callables can fail when first invoked and again later, while
    the returned stream is iterated. The wrapper converts
    :class:`grpc.RpcError` raised by the invocation, including the optional
    prefetch of the first response. It returns a
    :class:`_StreamingResponseIterator`, which converts errors raised during
    iteration.

    Args:
        callable_ (Callable): A streaming gRPC callable.

    Returns:
        Callable: The wrapped callable.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            stream = callable_(*args, **kwargs)
            # Auto-fetching the first result causes the PubSub client's
            # streaming pull to hang when re-opening the stream. This hidden
            # flag lets callers disable pre-fetching; it is read on every call.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                stream, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return functools.wraps(callable_)(error_remapped_callable)
156
157
def wrap_errors(callable_):
    """Wrap a gRPC callable and map :class:`grpc.RpcErrors` to friendly error
    classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses.
    The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    # Stream-returning callables need their iterator wrapped as well.
    is_streaming = isinstance(callable_, _STREAM_WRAP_CLASSES)
    wrapper = _wrap_stream_errors if is_streaming else _wrap_unary_errors
    return wrapper(callable_)
178
179
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite credentials for secure channels.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        credentials_file (str): Deprecated. A file with credentials that can be
            loaded with :func:`google.auth.load_credentials_from_file`. This
            argument is mutually exclusive with credentials. This argument will
            be removed in the next major version of `google-api-core`.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        default_scopes (Sequence[str]): An optional list of default scopes,
            e.g. supplied by a Google client library. Forwarded alongside
            ``scopes`` as described below.
        scopes (Sequence[str]): An optional list of user-specified scopes.
            Both ``scopes`` and ``default_scopes`` are forwarded to the
            credential source that is used:
            :func:`google.auth.load_credentials_from_file` for
            ``credentials_file``,
            :func:`google.auth.credentials.with_scopes_if_required` for an
            explicit ``credentials`` object, or :func:`google.auth.default`
            when neither is given.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        quota_project_id (str): An optional project to use for billing and
            quota. Only applied when the credentials support a quota project.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.

    Warns:
        DeprecationWarning: If ``credentials_file`` is passed.
    """
    if credentials_file is not None:
        warnings.warn(general_helpers._CREDENTIALS_FILE_WARNING, DeprecationWarning)

    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    # Resolve the base credentials from exactly one source, applying scopes.
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, scopes=scopes, default_scopes=default_scopes
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, scopes=scopes, default_scopes=default_scopes
        )
    else:
        credentials, _ = google.auth.default(
            scopes=scopes, default_scopes=default_scopes
        )

    # Only credentials types that support a quota project can carry one.
    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    request = google.auth.transport.requests.Request()

    # Create the metadata plugin for inserting the authorization header.
    metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials,
        request,
        default_host=default_host,
    )

    # Create a set of grpc.CallCredentials using the metadata plugin.
    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

    # if `ssl_credentials` is set, use `grpc.composite_channel_credentials` instead of
    # `grpc.compute_engine_channel_credentials` as the former supports passing
    # `ssl_credentials` via `channel_credentials` which is needed for mTLS.
    if ssl_credentials:
        # Combine the ssl credentials and the authorization credentials.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.composite_channel_credentials
        return grpc.composite_channel_credentials(
            ssl_credentials, google_auth_credentials
        )
    else:
        # Use grpc.compute_engine_channel_credentials in order to support Direct Path.
        # See https://grpc.github.io/grpc/python/grpc.html#grpc.compute_engine_channel_credentials
        # TODO(https://github.com/googleapis/python-api-core/issues/598):
        # Although `grpc.compute_engine_channel_credentials` returns channel credentials
        # outside of a Google Compute Engine environment (GCE), we should determine if
        # there is a way to reliably detect a GCE environment so that
        # `grpc.compute_engine_channel_credentials` is not called outside of GCE.
        return grpc.compute_engine_channel_credentials(google_auth_credentials)
284
285
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    attempt_direct_path: Optional[bool] = False,
    **kwargs,
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of user-specified scopes.
            They are applied to whichever credentials are used: the explicit
            ``credentials`` (when they require scoping), those loaded from
            ``credentials_file``, or those found by :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): Deprecated. A file with credentials that can be
            loaded with :func:`google.auth.load_credentials_from_file`. This
            argument is mutually exclusive with credentials. Passing it emits
            a :class:`DeprecationWarning`.

            .. warning::
                Important: If you accept a credential configuration (credential JSON/File/Stream)
                from an external source for authentication to Google Cloud Platform, you must
                validate it before providing it to any Google API or client library. Providing an
                unvalidated credential configuration to Google APIs or libraries can compromise
                the security of your systems and data. For more information, refer to
                `Validate credential configurations from external sources`_.

            .. _Validate credential configurations from external sources:

            https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to be used over the lifetime of the channel.
        attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted
            when the request is made. Direct Path is only available within a Google
            Compute Engine (GCE) environment and provides a proxyless connection
            which increases the available throughput, reduces latency, and increases
            reliability. Note:

            - This argument should only be set in a GCE environment and for Services
              that are known to support Direct Path.
            - If this argument is set outside of GCE, then this request will fail
              unless the back-end service happens to have configured fall-back to DNS.
            - If the request causes a `ServiceUnavailable` response, it is recommended
              that the client repeat the request with `attempt_direct_path` set to
              `False` as the Service may not support Direct Path.
            - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
              result in `ValueError` as this combination is not yet supported.

        kwargs: Additional keyword args passed to
            :func:`grpc.secure_channel`.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`.
    """

    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`,
    # raise ValueError as this is not yet supported.
    # See https://github.com/googleapis/python-api-core/issues/590
    if ssl_credentials and attempt_direct_path:
        raise ValueError("Using ssl_credentials with Direct Path is not supported")

    composite_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    if attempt_direct_path:
        target = _modify_target_for_direct_path(target)

    return grpc.secure_channel(
        target, composite_credentials, compression=compression, **kwargs
    )
381
382
383def _modify_target_for_direct_path(target: str) -> str:
384 """
385 Given a target, return a modified version which is compatible with Direct Path.
386
387 Args:
388 target (str): The target service address in the format 'hostname[:port]' or
389 'dns://hostname[:port]'.
390
391 Returns:
392 target (str): The target service address which is converted into a format compatible with Direct Path.
393 If the target contains `dns:///` or does not contain `:///`, the target will be converted in
394 a format compatible with Direct Path; otherwise the original target will be returned as the
395 original target may already denote Direct Path.
396 """
397
398 # A DNS prefix may be included with the target to indicate the endpoint is living in the Internet,
399 # outside of Google Cloud Platform.
400 dns_prefix = "dns:///"
401 # Remove "dns:///" if `attempt_direct_path` is set to True as
402 # the Direct Path prefix `google-c2p:///` will be used instead.
403 target = target.replace(dns_prefix, "")
404
405 direct_path_separator = ":///"
406 if direct_path_separator not in target:
407 target_without_port = target.split(":")[0]
408 # Modify the target to use Direct Path by adding the `google-c2p:///` prefix
409 target = f"google-c2p{direct_path_separator}{target_without_port}"
410 return target
411
412
413_MethodCall = collections.namedtuple(
414 "_MethodCall", ("request", "timeout", "metadata", "credentials", "compression")
415)
416
417_ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))
418
419
class _CallableStub(object):
    """Stub for the grpc.*MultiCallable interfaces.

    Every invocation is recorded on this stub (``calls``, ``requests``) and on
    the owning channel stub (``channel.requests``). It is then answered from
    ``response`` or ``responses``.
    """

    def __init__(self, method, channel):
        self._method = method
        self._channel = channel
        self.response = None
        """Union[protobuf.Message, Callable[protobuf.Message], exception]:
        The response to give when invoking this callable. If this is a
        callable, it will be invoked with the request protobuf. If it's an
        exception, the exception will be raised when this is invoked.
        """
        self.responses = None
        """Iterator[
            Union[protobuf.Message, Callable[protobuf.Message], exception]]:
        An iterator of responses. If specified, each invocation takes its
        response from ``next(self.responses)`` and handles it exactly like
        ``self.response``. ``self.response`` itself is not modified and must
        remain ``None``. Once the iterator is exhausted, ``StopIteration``
        propagates from the invocation."""
        self.requests = []
        """List[protobuf.Message]: All requests sent to this callable."""
        self.calls = []
        """List[_MethodCall]: All invocations of this callable. Each tuple is
        the request, timeout, metadata, credentials, and compression."""

    def __call__(
        self, request, timeout=None, metadata=None, credentials=None, compression=None
    ):
        # Record the invocation before producing a response, so it is tracked
        # even when the configured response raises.
        self._channel.requests.append(_ChannelRequest(self._method, request))
        self.calls.append(
            _MethodCall(request, timeout, metadata, credentials, compression)
        )
        self.requests.append(request)

        response = self.response
        if self.responses is not None:
            if response is None:
                response = next(self.responses)
            else:
                raise ValueError(
                    "{method}.response and {method}.responses are mutually "
                    "exclusive.".format(method=self._method)
                )

        # Callables compute the response from the request.
        if callable(response):
            return response(request)

        # Exception instances are raised instead of returned.
        if isinstance(response, Exception):
            raise response

        if response is not None:
            return response

        raise ValueError('Method stub for "{}" has no response.'.format(self._method))
472
473
474def _simplify_method_name(method):
475 """Simplifies a gRPC method name.
476
477 When gRPC invokes the channel to create a callable, it gives a full
478 method name like "/google.pubsub.v1.Publisher/CreateTopic". This
479 returns just the name of the method, in this case "CreateTopic".
480
481 Args:
482 method (str): The name of the method.
483
484 Returns:
485 str: The simplified name of the method.
486 """
487 return method.rsplit("/", 1).pop()
488
489
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here are some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        # ``responses`` is accepted for backward compatibility but is not used;
        # configure responses per method instead (see the class docstring).
        # Its default is ``None`` rather than a shared mutable ``[]``.
        self.requests = []
        """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
        on this channel in order. The tuple is of method name, request
        message."""
        self._method_stubs = {}

    def _stub_for_method(self, method):
        # Stubs are keyed by the short method name so they can be reached as
        # attributes (``channel_stub.GetFoo``). Each call creates a fresh stub,
        # replacing any earlier stub for the same method.
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        # Read the stub table through the instance ``__dict__``. Accessing
        # ``self._method_stubs`` on an instance whose __init__ never ran (as
        # when copy/pickle create it via __new__) would re-enter __getattr__
        # and recurse without bound instead of raising AttributeError.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(key) from None

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation."""
        pass

    def close(self):
        """grpc.Channel.close implementation."""
        pass