1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import sys
18from typing import Any, Iterable, List, Optional, Sequence
19
20import grpc
21from grpc import _common
22from grpc import _compression
23from grpc import _grpcio_metadata
24from grpc._cython import cygrpc
25
26from . import _base_call
27from . import _base_channel
28from ._call import StreamStreamCall
29from ._call import StreamUnaryCall
30from ._call import UnaryStreamCall
31from ._call import UnaryUnaryCall
32from ._interceptor import ClientInterceptor
33from ._interceptor import InterceptedStreamStreamCall
34from ._interceptor import InterceptedStreamUnaryCall
35from ._interceptor import InterceptedUnaryStreamCall
36from ._interceptor import InterceptedUnaryUnaryCall
37from ._interceptor import StreamStreamClientInterceptor
38from ._interceptor import StreamUnaryClientInterceptor
39from ._interceptor import UnaryStreamClientInterceptor
40from ._interceptor import UnaryUnaryClientInterceptor
41from ._metadata import Metadata
42from ._typing import ChannelArgumentType
43from ._typing import DeserializingFunction
44from ._typing import MetadataType
45from ._typing import RequestIterableType
46from ._typing import RequestType
47from ._typing import ResponseType
48from ._typing import SerializingFunction
49from ._utils import _timeout_to_deadline
50
# Library identifier sent as the primary user-agent channel argument.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
52
# Compare the full version tuple rather than only the minor component: a
# minor-only check would misclassify any future major release whose minor
# number is below 7, and ``asyncio.Task.all_tasks`` no longer exists on
# modern interpreters (removed in Python 3.9).
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks of the current event loop (Python < 3.7 API)."""
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the not-yet-finished tasks of the running event loop.

        ``asyncio.all_tasks`` requires a running loop, so this must be called
        from code executing inside that loop.
        """
        return asyncio.all_tasks()
62
63
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extends user channel options with the arguments gRPC always adds.

    Args:
      base_options: The caller-supplied channel arguments.
      compression: Channel-wide compression setting, turned into a channel
        option by ``_compression.create_channel_option``.

    Returns:
      A tuple made of ``base_options``, then the compression option(s), then
      the primary user-agent argument carrying ``_USER_AGENT``.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return (*base_options, *compression_options, user_agent_option)
81
82
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    # Declared once; the original listed ``_loop`` twice.
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores the state shared by every RPC started from this object.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The RPC method name as bytes (``Channel`` passes
            ``_common.encode(method)``).
          request_serializer: Optional function used to serialize requests.
          response_deserializer: Optional function used to deserialize
            responses.
          interceptors: Client interceptors for this RPC type; subclasses
            issue non-intercepted calls when this is empty or None.
          references: Objects held by this multi callable. The ``Channel``
            methods in this module pass a one-element list containing the
            Channel (presumably to keep it alive; confirm).
          loop: Event loop handed to the created call objects.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Builds the final metadata to send with a single call.

        Args:
          metadata: Caller-supplied metadata. ``None`` or any falsy value
            becomes an empty ``Metadata``. A non-``Metadata`` sequence (e.g. a
            tuple of pairs) is converted with ``Metadata.from_tuple``. Any
            other value is passed through unchanged.
          compression: When set, the compression request is merged in via
            ``_compression.augment_metadata`` and a new ``Metadata`` is built.

        Returns:
          The metadata for the call. Without ``compression``, a
          caller-supplied ``Metadata`` instance is returned as the same
          object, not a copy.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(
            metadata, Sequence
        ):
            metadata = Metadata.from_tuple(tuple(metadata))
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
135
136
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Starts unary-request / unary-response RPCs for a single method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Starts the RPC and returns its call object.

        With interceptors configured, an ``InterceptedUnaryUnaryCall`` is
        returned and receives the raw ``timeout``. Otherwise a plain
        ``UnaryUnaryCall`` is returned, and the timeout is first passed
        through ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return UnaryUnaryCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
180
181
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Starts unary-request / streaming-response RPCs for a single method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Starts the RPC and returns its call object.

        With interceptors configured, an ``InterceptedUnaryStreamCall`` is
        returned and receives the raw ``timeout``. Otherwise a plain
        ``UnaryStreamCall`` is returned, and the timeout is first passed
        through ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return UnaryStreamCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
226
227
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Starts streaming-request / unary-response RPCs for a single method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Starts the RPC and returns its call object.

        ``request_iterator`` may be None. With interceptors configured, an
        ``InterceptedStreamUnaryCall`` is returned and receives the raw
        ``timeout``. Otherwise a plain ``StreamUnaryCall`` is returned, and
        the timeout is first passed through ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return StreamUnaryCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
271
272
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Starts streaming-request / streaming-response RPCs for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Starts the RPC and returns its call object.

        ``request_iterator`` may be None. With interceptors configured, an
        ``InterceptedStreamStreamCall`` is returned and receives the raw
        ``timeout``. Otherwise a plain ``StreamStreamCall`` is returned, and
        the timeout is first passed through ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return StreamStreamCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
316
317
class Channel(_base_channel.Channel):
    """Asyncio channel that issues RPCs over a Cython ``AioChannel``.

    The module-level ``insecure_channel`` and ``secure_channel`` factories
    construct instances of this class. Client interceptors given to the
    constructor are sorted by kind and attached to the multi callables
    returned by ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream``.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # elif chain: an interceptor that subclasses several
                # interceptor types is registered only under the first match.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context closes the channel without a grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Shuts the channel down; shared by ``close`` and ``__aexit__``.

        Does nothing if the Cython channel is already closed. Otherwise it:
        1. marks the Cython channel as closing;
        2. collects the running tasks whose coroutine frame has a ``self``
           that is a ``_base_call.Call`` bound to this channel;
        3. if ``grace`` is truthy, waits up to ``grace`` seconds for those
           tasks;
        4. cancels the collected calls and closes the Cython channel.

        Raises:
          cygrpc.InternalError: If a ``_base_call.Call`` found on a task has
            neither a ``_channel`` nor a ``_cython_call`` attribute.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
            # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
            # might not always be the case.
            if candidate is not None:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        error_msg = f"Unrecognized call object: {candidate}"
                        raise cygrpc.InternalError(error_msg)

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes this channel.

        Args:
          grace: Optional number of seconds to wait for this channel's
            in-flight calls before they are cancelled. ``None`` or ``0``
            cancels them without waiting.
        """
        await self._close(grace)

    def __del__(self):
        # Best-effort cleanup. The hasattr guard covers __init__ raising
        # (e.g. the interceptor ValueError) before _channel was assigned.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Returns the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            ``check_connectivity_state`` (presumably asks an idle channel to
            start connecting; confirm against cygrpc).
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits for the connectivity state to move away from ``last_observed_state``.

        Delegates to the Cython channel's ``watch_connectivity_state`` with no
        deadline.

        Args:
          last_observed_state: The state the caller last saw.
        """
        # Await outside the assert statement. Under ``python -O`` asserts are
        # stripped together with their expression, so an ``assert await ...``
        # would skip the watch entirely. ``channel_ready`` would then spin
        # without ever yielding to the event loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Returns once ``get_state(try_to_connect=True)`` reports READY."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # Placeholder: currently returns None despite the ``int`` annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Creates a multi callable for a unary-unary ``method``.

        ``_registered_method`` is accepted for interface compatibility and
        ignored.
        """
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Creates a multi callable for a unary-stream ``method``.

        ``_registered_method`` is accepted for interface compatibility and
        ignored.
        """
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Creates a multi callable for a stream-unary ``method``.

        ``_registered_method`` is accepted for interface compatibility and
        ignored.
        """
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Creates a multi callable for a stream-stream ``method``.

        ``_registered_method`` is accepted for interface compatibility and
        ignored.
        """
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
569
570
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    # A missing options argument becomes an empty tuple; no credentials are
    # passed, which is what makes the channel insecure.
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
598
599
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    # Channel expects the wrapped Cython credentials object, not the public
    # grpc.ChannelCredentials wrapper.
    cython_credentials = credentials._credentials
    return Channel(
        target,
        channel_options,
        cython_credentials,
        compression,
        interceptors,
    )