Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/grpc/aio/_call.py: 36%
344 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-16 06:17 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-16 06:17 +0000
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
16import asyncio
17import enum
18from functools import partial
19import inspect
20import logging
21import traceback
22from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
24import grpc
25from grpc import _common
26from grpc._cython import cygrpc
28from . import _base_call
29from ._metadata import Metadata
30from ._typing import DeserializingFunction
31from ._typing import DoneCallbackType
32from ._typing import MetadatumType
33from ._typing import RequestIterableType
34from ._typing import RequestType
35from ._typing import ResponseType
36from ._typing import SerializingFunction
# Public call classes implemented by this module.
__all__ = (
    "AioRpcError",
    "Call",
    "UnaryUnaryCall",
    "UnaryStreamCall",
    "StreamUnaryCall",
    "StreamStreamCall",
)

# Details reported when the application cancels a call via Call.cancel().
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
# Details reported when a still-running call is cancelled from Call.__del__.
_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"
# Messages for writes attempted on a finished or half-closed request stream.
_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
# Raised when iterator-style and read()/write()-style APIs are mixed.
_API_STYLE_ERROR = (
    "The iterator and read/write APIs may not be mixed on a single RPC."
)

# Templates for the textual representation of a terminated RPC.
_OK_CALL_REPRESENTATION = (
    '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
)

_NON_OK_CALL_REPRESENTATION = (
    "<{} of RPC that terminated with:\n"
    "\tstatus = {}\n"
    '\tdetails = "{}"\n'
    '\tdebug_error_string = "{}"\n'
    ">"
)

_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised AioRpcError is a snapshot of the final status of the RPC: all of
    its values are already determined, hence its methods do not need to be
    coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(
        self,
        code: grpc.StatusCode,
        initial_metadata: Metadata,
        trailing_metadata: Metadata,
        details: Optional[str] = None,
        debug_error_string: Optional[str] = None,
    ) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: Optional initial metadata that could be sent by the
            Server.
          trailing_metadata: Optional metadata that could be sent by the Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug error string giving additional
            diagnostic information about the failure.
        """

        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error, or None if none was provided.
        """
        return self._details

    def initial_metadata(self) -> Metadata:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Metadata:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(
            self.__class__.__name__,
            self._code,
            self._details,
            self._debug_error_string,
        )

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
def _create_rpc_error(
    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
) -> AioRpcError:
    """Builds an AioRpcError snapshot describing a finished RPC.

    Args:
      initial_metadata: The RPC's initial metadata; it is normalised through
        `Metadata.from_tuple` before being stored on the error.
      status: The final status object reported by the Cython layer.

    Returns:
      An AioRpcError carrying the translated status code, both metadata sets,
      the details and the debug error string.
    """
    grpc_code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
    leading = Metadata.from_tuple(initial_metadata)
    trailing = Metadata.from_tuple(status.trailing_metadata())
    return AioRpcError(
        grpc_code,
        leading,
        trailing,
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )
class Call:
    """Base implementation of client RPC Call object.

    Implements logic around final status, metadata and cancellation.
    """

    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(
        self,
        cython_call: cygrpc._AioCall,
        metadata: Metadata,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._cython_call = cython_call
        # Stored as a tuple of metadatum pairs for the Cython layer.
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        # The '_cython_call' object might be destructed before Call object,
        # or never assigned at all; only cancel if it exists and is running.
        if hasattr(self, "_cython_call"):
            if not self._cython_call.done():
                self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning.

        Returns:
          True if the call was still running and has been cancelled, False if
          it had already finished.
        """
        if not self._cython_call.done():
            self._cython_call.cancel(details)
            return True
        else:
            return False

    def cancel(self) -> bool:
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        # The callback receives this Call object as its only argument.
        cb = partial(callback, self)
        self._cython_call.add_done_callback(cb)

    def time_remaining(self) -> Optional[float]:
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        raw_metadata_tuple = await self._cython_call.initial_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def trailing_metadata(self) -> Metadata:
        raw_metadata_tuple = (
            await self._cython_call.status()
        ).trailing_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def code(self) -> grpc.StatusCode:
        cygrpc_code = (await self._cython_call.status()).code()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]

    async def details(self) -> str:
        return (await self._cython_call.status()).details()

    async def debug_error_string(self) -> str:
        return (await self._cython_call.status()).debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the finished RPC did not end with an OK status.

        Raises:
          asyncio.CancelledError: If the call was cancelled locally.
          AioRpcError: If the RPC terminated with a non-OK status code.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        # Fetch the final status once and reuse it for both the code check
        # and the error construction.
        status = await self._cython_call.status()
        code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
        if code != grpc.StatusCode.OK:
            # _create_rpc_error applies Metadata.from_tuple itself, so hand it
            # the raw tuple (as _UnaryResponseMixin.__await__ does) rather than
            # an already-converted Metadata object.
            raise _create_rpc_error(
                await self._cython_call.initial_metadata(), status
            )

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
class _APIStyle(enum.IntEnum):
    """Which consumption/production API a streaming call has committed to.

    UNKNOWN (0) means no style chosen yet, ASYNC_GENERATOR (1) is the
    iterator-based API, READER_WRITER (2) is the explicit read()/write() API.
    """

    UNKNOWN, ASYNC_GENERATOR, READER_WRITER = range(3)
class _UnaryResponseMixin(Call, Generic[ResponseType]):
    """Mixin for calls whose result is a single response obtained by awaiting."""

    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        """Records the task that will produce the (single) response."""
        self._call_response = response_task

    def cancel(self) -> bool:
        # Only tear down the response task when the call itself was cancelled.
        if not super().cancel():
            return False
        self._call_response.cancel()
        return True

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Wait till the ongoing RPC request finishes."""
        try:
            result = yield from self._call_response
        except asyncio.CancelledError:
            # Corner case: the application cancelled right after the Call
            # object was created, so the cancellation surfaces here.
            if not self.cancelled():
                self.cancel()
            raise

        # The response task reports failure with EOF instead of raising, so
        # that an un-awaited task never logs "Task exception was never
        # retrieved". Only one 'yield from' is allowed in '__await__', hence
        # the status is read from the private Cython attributes.
        if result is not cygrpc.EOF:
            return result
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        raise _create_rpc_error(
            self._cython_call._initial_metadata,
            self._cython_call._status,
        )
class _StreamResponseMixin(Call):
    """Mixin implementing the response-streaming side of an RPC.

    Responses are consumed either with ``async for`` (through ``__aiter__``)
    or with ``read()``; the two styles may not be mixed on a single call.
    """

    _message_aiter: Optional[AsyncIterator[ResponseType]]
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        """Initializes response-side state.

        Args:
          preparation: Task that starts the RPC; every read awaits it first.
        """
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Locks in the first style used and rejects a different one later.

        Raises:
          cygrpc.UsageError: If a different style was already chosen.
        """
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterator[ResponseType]:
        """Async generator yielding responses until EOF, then checking status."""
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        # Reuse one generator so repeated iteration continues the same stream.
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        """Reads and deserializes one message, or returns cygrpc.EOF."""
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(
                raw_response, self._response_deserializer
            )

    async def read(self) -> ResponseType:
        """Reads the next response message.

        Returns:
          The next deserialized message, or cygrpc.EOF when the stream ended
          with an OK status.

        Raises:
          asyncio.CancelledError: If the call was cancelled locally.
          AioRpcError: If the RPC finished with a non-OK status.
          cygrpc.UsageError: If the iterator API is already in use.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message
class _StreamRequestMixin(Call):
    """Mixin implementing the request-streaming side of an RPC.

    Requests come either from a user-supplied (async) iterator, consumed by a
    background task, or from explicit ``write()``/``done_writing()`` calls;
    the two styles may not be mixed on a single call.
    """

    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ):
        """Initializes request-side state.

        Args:
          request_iterator: An iterable or async iterable of requests, or None
            to use the explicit write()/done_writing() API.
        """
        self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator)
            )
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        # Passed to the Cython layer; unblocks writers waiting on metadata.
        self._metadata_sent.set()

    async def _write_iterator_request(self, request: RequestType) -> bool:
        """Writes one request taken from the user's request iterator.

        Returns:
          True if the request was written; False if the write failed with an
          AioRpcError (logged at debug level), meaning consumption must stop.
        """
        try:
            await self._write(request)
        except AioRpcError as rpc_error:
            _LOGGER.debug(
                "Exception while consuming the request_iterator: %s",
                rpc_error,
            )
            return False
        return True

    async def _consume_request_iterator(
        self, request_iterator: RequestIterableType
    ) -> None:
        """Background task writing every request, then half-closing the call."""
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                request_iterator, "__aiter__"
            ):
                async for request in request_iterator:
                    if not await self._write_iterator_request(request):
                        return
            else:
                for request in request_iterator:
                    if not await self._write_iterator_request(request):
                        return

            await self._done_writing()
        except:  # pylint: disable=bare-except
            # Client iterators can raise exceptions, which we should handle by
            # cancelling the RPC and logging the client's error. No exceptions
            # should escape this function.
            _LOGGER.debug(
                "Client request_iterator raised exception:\n%s",
                traceback.format_exc(),
            )
            self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Raises:
          asyncio.InvalidStateError: If the RPC is finished or half-closed.
          AioRpcError: If the RPC failed (via _raise_for_status).
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except cygrpc.InternalError:
            await self._raise_for_status()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def _done_writing(self) -> None:
        """Half-closes the request stream at most once; no-op when finished."""
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                raise

    async def write(self, request: RequestType) -> None:
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """

    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Performs the RPC.

        Returns:
          The deserialized response on an OK status, otherwise cygrpc.EOF
          (the awaiting side turns EOF into the appropriate exception).
        """
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            if self._cython_call.is_ok():
                # The call reports OK, yet no response was captured because
                # the cancellation hit this task rather than the RPC.
                # Propagate it instead of reading the unbound
                # `serialized_response` below.
                raise
            return cygrpc.EOF

        if self._cython_call.is_ok():
            return _common.deserialize(
                serialized_response, self._response_deserializer
            )
        else:
            return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        await self._invocation_task
        if self.done():
            await self._raise_for_status()
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """

    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request()
        )
        # Reads wait for this task, so responses are only read after the
        # single request has been sent.
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> None:
        """Serializes the request and starts the unary-stream RPC."""
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()
# pylint: disable=too-many-ancestors
class StreamUnaryCall(
    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    The requests are streamed in; the single response is obtained by awaiting
    the call.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        underlying = channel.call(method, deadline, credentials, wait_for_ready)
        super().__init__(
            underlying,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )

        self._init_stream_request_mixin(request_iterator)
        rpc_task = loop.create_task(self._conduct_rpc())
        self._init_unary_response_mixin(rpc_task)

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the stream-unary RPC; returns the response or cygrpc.EOF."""
        try:
            raw_reply = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(raw_reply, self._response_deserializer)
class StreamStreamCall(
    _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    Both requests and responses are streamed.
    """

    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        underlying = channel.call(method, deadline, credentials, wait_for_ready)
        super().__init__(
            underlying,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        # The preparation task is scheduled before the request-consumer task
        # and doubles as the gate every response read waits on.
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self):
        """This method prepares the RPC for receiving/sending messages.

        All other operations around the stream should only happen after the
        completion of this method.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer
            )
        except asyncio.CancelledError:
            if self.cancelled():
                return
            self.cancel()
            # No need to raise RpcError here, because no one will `await` this task.