1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import sys
18from typing import Any, Iterable, List, Optional, Sequence
19
20import grpc
21from grpc import _common
22from grpc import _compression
23from grpc import _grpcio_metadata
24from grpc._cython import cygrpc
25
26from . import _base_call
27from . import _base_channel
28from ._call import StreamStreamCall
29from ._call import StreamUnaryCall
30from ._call import UnaryStreamCall
31from ._call import UnaryUnaryCall
32from ._interceptor import ClientInterceptor
33from ._interceptor import InterceptedStreamStreamCall
34from ._interceptor import InterceptedStreamUnaryCall
35from ._interceptor import InterceptedUnaryStreamCall
36from ._interceptor import InterceptedUnaryUnaryCall
37from ._interceptor import StreamStreamClientInterceptor
38from ._interceptor import StreamUnaryClientInterceptor
39from ._interceptor import UnaryStreamClientInterceptor
40from ._interceptor import UnaryUnaryClientInterceptor
41from ._metadata import Metadata
42from ._typing import ChannelArgumentType
43from ._typing import DeserializingFunction
44from ._typing import MetadataType
45from ._typing import RequestIterableType
46from ._typing import RequestType
47from ._typing import ResponseType
48from ._typing import SerializingFunction
49from ._utils import _timeout_to_deadline
50
# Primary user-agent string advertised by every asyncio channel; it is
# attached as a channel argument by ``_augment_channel_arguments``.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
52
53if sys.version_info[1] < 7:
54
55 def _all_tasks() -> Iterable[asyncio.Task]:
56 return asyncio.Task.all_tasks() # pylint: disable=no-member
57
58else:
59
60 def _all_tasks() -> Iterable[asyncio.Task]:
61 return asyncio.all_tasks()
62
63
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend caller-supplied channel options with gRPC-managed arguments.

    Args:
      base_options: The caller's channel arguments; converted with ``tuple()``.
      compression: Optional channel-wide compression, turned into channel
        option(s) by ``_compression.create_channel_option``.

    Returns:
      A tuple made of the caller options, followed by the compression
      option(s), followed by the primary user-agent option.
    """
    managed_arguments = _compression.create_channel_option(compression) + (
        (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),
    )
    return tuple(base_options) + managed_arguments
81
82
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores the attributes shared by the
    unary-unary, unary-stream, stream-unary and stream-stream multi-callables.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    # Arbitrary objects held by this multi-callable; ``Channel`` passes
    # ``[self]``. NOTE(review): presumably kept so the owning Channel stays
    # alive as long as its callables do — confirm.
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the per-method settings used when a call is started.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The encoded fully-qualified method name.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Client interceptors for this RPC type; when empty or
            None, calls are created without interception.
          references: Objects to keep referenced by this multi-callable.
          loop: The event loop the calls run on.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the final metadata for the current call.

        Args:
          metadata: Caller-supplied metadata: a ``Metadata``, a tuple of
            (key, value) pairs, or None.
          compression: Optional per-call compression; when set, the metadata
            is rebuilt through ``_compression.augment_metadata``.

        Returns:
          The metadata to send. ``None`` or any falsy value yields an empty
          ``Metadata``. Tuples are converted to ``Metadata``. Other
          non-``Metadata`` sequences are returned unchanged unless
          ``compression`` is set.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            # ``augment_metadata`` is expected to return an iterable of
            # (key, value) pairs that carries the compression setting — confirm.
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
133
134
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Multi-callable for RPCs with a single request and a single response."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary RPC and return its call object.

        Without interceptors, the relative ``timeout`` is converted into an
        absolute deadline here. With interceptors, the raw ``timeout`` and the
        interceptor chain are handed to ``InterceptedUnaryUnaryCall``.
        """
        final_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            call_factory = InterceptedUnaryUnaryCall
            leading_args = (self._interceptors, request, timeout)
        else:
            call_factory = UnaryUnaryCall
            leading_args = (request, _timeout_to_deadline(timeout))

        return call_factory(
            *leading_args,
            final_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
178
179
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Multi-callable for RPCs with a single request and a response stream."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream RPC and return its call object.

        The plain call receives an absolute deadline computed from
        ``timeout``. The intercepted call receives the interceptors and the
        raw ``timeout``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        shared_tail = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *shared_tail
            )
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *shared_tail
        )
224
225
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Multi-callable for RPCs with a request stream and a single response."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC and return its call object.

        ``request_iterator`` may be None. Interceptors, when configured, get
        the raw ``timeout``. Otherwise it is converted to an absolute deadline.
        """
        prepared_metadata = self._init_metadata(metadata, compression)
        trailing = (
            prepared_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        rpc = (
            InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *trailing
            )
            if self._interceptors
            else StreamUnaryCall(
                request_iterator, _timeout_to_deadline(timeout), *trailing
            )
        )
        return rpc
269
270
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Multi-callable for RPCs with both a request and a response stream."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a bidirectional-streaming RPC and return its call object.

        ``request_iterator`` may be None. Intercepted calls receive the raw
        ``timeout``. Plain calls receive the absolute deadline derived from it.
        """
        effective_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                effective_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        deadline = _timeout_to_deadline(timeout)
        return StreamStreamCall(
            request_iterator,
            deadline,
            effective_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
314
315
class Channel(_base_channel.Channel):
    """Asynchronous gRPC channel backed by a Cython ``AioChannel``.

    Interceptors given to the constructor are sorted by the interceptor
    interface they implement. Each group is attached to the multi-callables
    created by ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` respectively.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: if an interceptor implements none of the four client
            interceptor interfaces.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # The elif chain means an interceptor implementing several
                # interfaces is registered only for the first one that matches.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        """Enter the async context; returns this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the channel immediately (no grace period) on context exit."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel, cancelling calls still running on it.

        Args:
          grace: Seconds to wait for in-flight calls (found among the running
            asyncio tasks) to finish before cancelling them. None or 0 means
            cancel right away.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
            # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
            # might not always be the case.
            if candidate is not None:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, optionally giving in-flight calls ``grace`` seconds."""
        await self._close(grace)

    def __del__(self):
        # ``_channel`` may be missing if ``__init__`` raised before assigning it
        # (e.g. the interceptor ValueError), hence the hasattr guard.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        ``try_to_connect`` is forwarded to the Cython
        ``check_connectivity_state``.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait for the connectivity state to change from ``last_observed_state``.

        The Cython watcher is awaited with no deadline (``None``).
        """
        # Keep the await out of the ``assert``. Under ``python -O``, assert
        # statements are removed entirely, so the watcher would never be
        # awaited and ``channel_ready`` would spin without yielding.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reaches ``ChannelConnectivity.READY``."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # Not implemented yet: currently returns None despite the annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable for ``method``.

        ``_registered_method`` is accepted for interface compatibility but is
        currently ignored.
        """
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable for ``method``.

        ``_registered_method`` is accepted for interface compatibility but is
        currently ignored.
        """
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable for ``method``.

        ``_registered_method`` is accepted for interface compatibility but is
        currently ignored.
        """
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a stream-stream multi-callable for ``method``.

        ``_registered_method`` is accepted for interface compatibility but is
        currently ignored.
        """
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
568
569
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel created without channel credentials.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
597
598
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its underlying
        ``_credentials`` object is what the channel receives.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    core_credentials = credentials._credentials  # pylint: disable=protected-access
    return Channel(
        target, channel_options, core_credentials, compression, interceptors
    )