1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import sys
18from typing import Any, Iterable, List, Optional, Sequence
19
20import grpc
21from grpc import _common
22from grpc import _compression
23from grpc import _grpcio_metadata
24from grpc._cython import cygrpc
25
26from . import _base_call
27from . import _base_channel
28from ._call import StreamStreamCall
29from ._call import StreamUnaryCall
30from ._call import UnaryStreamCall
31from ._call import UnaryUnaryCall
32from ._interceptor import ClientInterceptor
33from ._interceptor import InterceptedStreamStreamCall
34from ._interceptor import InterceptedStreamUnaryCall
35from ._interceptor import InterceptedUnaryStreamCall
36from ._interceptor import InterceptedUnaryUnaryCall
37from ._interceptor import StreamStreamClientInterceptor
38from ._interceptor import StreamUnaryClientInterceptor
39from ._interceptor import UnaryStreamClientInterceptor
40from ._interceptor import UnaryUnaryClientInterceptor
41from ._metadata import Metadata
42from ._typing import ChannelArgumentType
43from ._typing import DeserializingFunction
44from ._typing import MetadataType
45from ._typing import RequestIterableType
46from ._typing import RequestType
47from ._typing import ResponseType
48from ._typing import SerializingFunction
49from ._utils import _timeout_to_deadline
50
# Primary user-agent string attached to channels built by this module
# (see `_augment_channel_arguments`); suffixed with the grpcio version
# reported by `_grpcio_metadata.__version__`.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
52
53if sys.version_info[1] < 7:
54
55 def _all_tasks() -> Iterable[asyncio.Task]:
56 return asyncio.Task.all_tasks() # pylint: disable=no-member
57
58else:
59
60 def _all_tasks() -> Iterable[asyncio.Task]:
61 return asyncio.all_tasks()
62
63
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend the caller's channel arguments with compression and user agent.

    Args:
      base_options: The caller-supplied channel arguments (any iterable
        accepted by ``tuple``).
      compression: The channel-wide compression method, or None.

    Returns:
      A tuple made of ``base_options``, then whatever
      ``_compression.create_channel_option(compression)`` produces, then a
      single ``(primary_user_agent_string, _USER_AGENT)`` pair.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_options = (
        (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),
    )
    caller_options = tuple(base_options)
    return caller_options + compression_options + user_agent_options
81
82
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    # The Channel factory methods in this module pass ``[self]`` here;
    # presumably this keeps the owning Channel alive while the callable is used.
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the attributes shared by every multi-callable.

        Args:
          channel: The Cython channel that calls are issued on.
          method: The method name, already encoded to bytes.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Interceptors for this kind of RPC; empty or None means
            calls are created without interception.
          references: Objects this callable keeps a reference to.
          loop: The event loop the created calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Based on the provided values for <metadata> or <compression> initialise the final
        metadata, as it should be used for the current call.

        Args:
          metadata: None, a ``Metadata`` instance, or any sequence of
            ``(key, value)`` pairs (tuple, list, ...).
          compression: Optional per-call compression. When set, the metadata
            is rebuilt from ``_compression.augment_metadata``.

        Returns:
          Always a ``Metadata`` instance; falsy input yields an empty one.
        """
        if not metadata:
            metadata = Metadata()
        elif not isinstance(metadata, Metadata):
            # Previously only tuples were converted, so a list of pairs leaked
            # through as a plain list despite the ``-> Metadata`` contract
            # (and broke interceptors relying on the Metadata API).
            metadata = Metadata.from_tuple(tuple(metadata))
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
133
134
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Creates unary-request / unary-response calls for a single method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary RPC and return its call object.

        With interceptors configured, an InterceptedUnaryUnaryCall receives the
        raw ``timeout``; otherwise a UnaryUnaryCall is built with the timeout
        converted by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return UnaryUnaryCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
178
179
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Creates unary-request / streaming-response calls for a single method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream RPC and return its call object.

        With interceptors configured, an InterceptedUnaryStreamCall receives
        the raw ``timeout``; otherwise a UnaryStreamCall is built with the
        timeout converted by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return UnaryStreamCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
224
225
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Creates streaming-request / unary-response calls for a single method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC and return its call object.

        With interceptors configured, an InterceptedStreamUnaryCall receives
        the raw ``timeout``; otherwise a StreamUnaryCall is built with the
        timeout converted by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return StreamUnaryCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
269
270
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Creates streaming-request / streaming-response calls for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream RPC and return its call object.

        With interceptors configured, an InterceptedStreamStreamCall receives
        the raw ``timeout``; otherwise a StreamStreamCall is built with the
        timeout converted by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return StreamStreamCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
314
315
class Channel(_base_channel.Channel):
    """Asyncio channel wrapping a Cython ``AioChannel``.

    Interceptors given at construction are sorted by the RPC kind they
    intercept and handed to the matching multi-callable factory method.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is none of the four supported client
            interceptor types.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # NOTE(review): because this is an if/elif chain, an
                # interceptor implementing several interfaces is registered
                # only for the first matching kind -- confirm this is intended.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context closes immediately: no grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Cancel this channel's active calls and destroy the channel.

        Args:
          grace: Seconds to wait for active calls to finish before they are
            cancelled; a falsy value cancels them without waiting.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`: the top frame's `self` is the
            # call object driving the task.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel.

        Args:
          grace: Optional number of seconds to let active calls finish before
            they are cancelled. None (or 0) cancels them immediately.
        """
        await self._close(grace)

    def __del__(self):
        # `_channel` may be missing if __init__ raised (e.g. ValueError for a
        # bad interceptor) before the Cython channel was created.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the current connectivity state of the channel.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the connectivity state differs from last_observed_state.

        Args:
          last_observed_state: The state last reported by ``get_state``.
        """
        # The await must NOT live inside the ``assert``: under ``python -O``
        # assert statements are removed entirely, which would skip the wait
        # and make ``channel_ready`` spin in a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reaches READY, asking it to connect."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # Placeholder: currently returns None despite the `int` annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
565
566
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None means no extra
        arguments.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel created without channel credentials.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
594
595
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its underlying
        ``_credentials`` object is handed to the Channel.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None means no extra
        arguments.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    cython_credentials = credentials._credentials
    return Channel(
        target, channel_options, cython_credentials, compression, interceptors
    )