Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/grpc/aio/_call.py: 35%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
16import asyncio
17import enum
18from functools import partial
19import inspect
20import logging
21import traceback
22from typing import (
23 Any,
24 AsyncIterator,
25 Generator,
26 Generic,
27 Optional,
28 Tuple,
29 Union,
30)
32import grpc
33from grpc import _common
34from grpc._cython import cygrpc
36from . import _base_call
37from ._metadata import Metadata
38from ._typing import DeserializingFunction
39from ._typing import DoneCallbackType
40from ._typing import EOFType
41from ._typing import MetadatumType
42from ._typing import RequestIterableType
43from ._typing import RequestType
44from ._typing import ResponseType
45from ._typing import SerializingFunction
# Names exported by a wildcard import of this module.
__all__ = (
    "AioRpcError",
    "Call",
    "UnaryUnaryCall",
    "UnaryStreamCall",
)
49_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
50_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"
51_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
52_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
53_API_STYLE_ERROR = (
54 "The iterator and read/write APIs may not be mixed on a single RPC."
55)
57_OK_CALL_REPRESENTATION = (
58 '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
59)
61_NON_OK_CALL_REPRESENTATION = (
62 "<{} of RPC that terminated with:\n"
63 "\tstatus = {}\n"
64 '\tdetails = "{}"\n'
65 '\tdebug_error_string = "{}"\n'
66 ">"
67)
69_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised AioRpcError is a snapshot of the final status of the RPC: all of
    its values are already determined when it is constructed, so its accessors
    are plain methods rather than coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(
        self,
        code: grpc.StatusCode,
        initial_metadata: Optional[Metadata],
        trailing_metadata: Optional[Metadata],
        details: Optional[str] = None,
        debug_error_string: Optional[str] = None,
    ) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: Optional initial metadata that could be sent by the
            Server.
          trailing_metadata: Optional metadata that could be sent by the Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug error string taken from the final
            status of the RPC.
        """

        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error, or None if none was provided.
        """
        return self._details

    def initial_metadata(self) -> Optional[Metadata]:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received, exactly as passed to the constructor.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Optional[Metadata]:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received, exactly as passed to the constructor.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(
            self.__class__.__name__,
            self._code,
            self._details,
            self._debug_error_string,
        )

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __reduce__(self):
        # Pickle support: rebuild the error by calling the constructor with the
        # stored values in its positional-argument order.
        return (
            type(self),
            (
                self._code,
                self._initial_metadata,
                self._trailing_metadata,
                self._details,
                self._debug_error_string,
            ),
        )
def _create_rpc_error(
    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
) -> AioRpcError:
    """Builds an AioRpcError snapshot of a finished RPC.

    Args:
      initial_metadata: Initial metadata of the RPC; normalized through
        `Metadata.from_tuple`.
      status: The final status object of the RPC.

    Returns:
      An AioRpcError carrying the mapped `grpc.StatusCode`, the initial and
      trailing metadata, the details and the debug error string.
    """
    status_code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
    return AioRpcError(
        status_code,
        Metadata.from_tuple(initial_metadata),
        Metadata.from_tuple(status.trailing_metadata()),
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )
class Call:
    """Base implementation of client RPC Call object.

    Wraps a Cython-level call object and exposes its final status, metadata
    and cancellation to the application.
    """

    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(
        self,
        cython_call: cygrpc._AioCall,
        metadata: Metadata,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._cython_call = cython_call
        # Snapshot the metadata into a tuple.
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        # '_cython_call' might already be gone (or never have been set) by the
        # time this object is collected, so check for it before touching it.
        if hasattr(self, "_cython_call") and not self._cython_call.done():
            self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning.

        Returns:
          True if the call was still running and got cancelled, else False.
        """
        if self._cython_call.done():
            return False
        self._cython_call.cancel(details)
        return True

    def cancel(self) -> bool:
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        # The callback receives this Call object as its only argument.
        self._cython_call.add_done_callback(partial(callback, self))

    def time_remaining(self) -> Optional[float]:
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        return Metadata.from_tuple(await self._cython_call.initial_metadata())

    async def trailing_metadata(self) -> Metadata:
        final_status = await self._cython_call.status()
        return Metadata.from_tuple(final_status.trailing_metadata())

    async def code(self) -> grpc.StatusCode:
        final_status = await self._cython_call.status()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[final_status.code()]

    async def details(self) -> str:
        final_status = await self._cython_call.status()
        return final_status.details()

    async def debug_error_string(self) -> str:
        final_status = await self._cython_call.status()
        return final_status.debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not end successfully.

        Raises:
          asyncio.CancelledError: if the call was cancelled locally.
          AioRpcError: if the RPC finished with a non-OK status code.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        if await self.code() == grpc.StatusCode.OK:
            return
        raise _create_rpc_error(
            await self.initial_metadata(), await self._cython_call.status()
        )

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
286class _APIStyle(enum.IntEnum):
287 UNKNOWN = 0
288 ASYNC_GENERATOR = 1
289 READER_WRITER = 2
class _UnaryResponseMixin(Call, Generic[ResponseType]):
    """Mixin for calls whose response is a single message.

    The response is produced by an asyncio.Task; awaiting the call waits for
    that task and turns an EOF result into the matching exception.
    """

    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        self._call_response = response_task

    def cancel(self) -> bool:
        if not super().cancel():
            return False
        # Also stop the task that produces the response.
        self._call_response.cancel()
        return True

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Wait till the ongoing RPC request finishes."""
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Corner case not covered elsewhere: if the application cancels
            # right after the Call object is created, the CancelledError
            # surfaces here, so make sure the RPC itself is cancelled too.
            if not self.cancelled():
                self.cancel()
            raise

        if response is not cygrpc.EOF:
            return response

        # NOTE(lidiz) Raising RpcError inside the task would make asyncio log
        # 'Task exception was never retrieved' when users don't await it, so
        # the error is raised here instead. Only one 'yield from' is allowed
        # in '__await__', hence the access to private Cython-call state.
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        raise _create_rpc_error(
            self._cython_call._initial_metadata,
            self._cython_call._status,
        )
class _StreamResponseMixin(Call):
    """Mixin for calls that receive a stream of response messages.

    Responses can be consumed either with ``async for`` or with ``read()``,
    but the two styles may not be mixed on the same call.
    """

    # None until __aiter__ is first called.
    _message_aiter: Optional[AsyncIterator[ResponseType]]
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Locks in the response consumption style on first use.

        Raises:
          cygrpc.UsageError: if a different style was already used.
        """
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            # Also stop the task that prepares the RPC.
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterator[ResponseType]:
        """Yields deserialized responses until EOF, then checks the status."""
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            # A single generator is shared by every `async for` on this call.
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> Union[EOFType, ResponseType]:
        """Reads one message; returns cygrpc.EOF when the stream ends."""
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(
                raw_response, self._response_deserializer
            )

    async def read(self) -> Union[EOFType, ResponseType]:
        """Reads the next response message.

        Returns:
          The next deserialized message, or cygrpc.EOF once the stream ended.

        Raises:
          cygrpc.UsageError: if the call is also being consumed via
            ``async for``.
          AioRpcError / asyncio.CancelledError: if the RPC ended unsuccessfully.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message
class _StreamRequestMixin(Call):
    """Mixin for calls that send a stream of request messages.

    Requests come either from an iterator given at call creation (drained by
    a background task) or from the application through ``write()`` and
    ``done_writing()``; the two styles may not be mixed.
    """

    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ):
        self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If the user passes in a request iterator (sync or async), create a
        # consumer Task for it.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator)
            )
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            # Also stop draining the request iterator, if any.
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        # Handed to the Cython call as a callback; unblocks pending writers.
        self._metadata_sent.set()

    async def _write_request_or_log(self, request: RequestType) -> bool:
        """Writes one request taken from the request iterator.

        Returns:
          True if the write succeeded, False if it failed with an AioRpcError
          (which is logged at debug level and otherwise swallowed).
        """
        try:
            await self._write(request)
        except AioRpcError as rpc_error:
            _LOGGER.debug(
                "Exception while consuming the request_iterator: %s",
                rpc_error,
            )
            return False
        return True

    async def _consume_request_iterator(
        self, request_iterator: RequestIterableType
    ) -> None:
        """Sends every request from the iterator, then half-closes.

        Stops early (without half-closing) when a write fails with an
        AioRpcError.
        """
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                request_iterator, "__aiter__"
            ):
                async for request in request_iterator:
                    if not await self._write_request_or_log(request):
                        return
            else:
                for request in request_iterator:
                    if not await self._write_request_or_log(request):
                        return

            await self._done_writing()
        except:  # pylint: disable=bare-except
            # Client iterators can raise exceptions, which we should handle by
            # cancelling the RPC and logging the client's error. No exceptions
            # should escape this function.
            _LOGGER.debug(
                "Client request_iterator raised exception:\n%s",
                traceback.format_exc(),
            )
            self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Raises:
          asyncio.InvalidStateError: if the RPC already finished or the client
            already half-closed the stream.
          AioRpcError / asyncio.CancelledError: if the RPC ended unsuccessfully
            (via _raise_for_status) or the send was cancelled.
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except cygrpc.InternalError as err:
            self._cython_call.set_internal_error(str(err))
            await self._raise_for_status()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def _done_writing(self) -> None:
        """Half-closes the stream once; a no-op if the RPC already finished."""
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                raise

    async def write(self, request: RequestType) -> None:
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """

    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        cython_call = channel.call(method, deadline, credentials, wait_for_ready)
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        # The whole RPC runs in this task; awaiting the call awaits it.
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Performs the RPC; returns the response, or cygrpc.EOF on failure."""
        request_bytes = _common.serialize(
            self._request, self._request_serializer
        )

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            response_bytes = await self._cython_call.unary_unary(
                request_bytes, self._metadata, self._context
            )
        except asyncio.CancelledError:
            # Deliberately not re-raised: the outcome is reported through
            # is_ok() below instead.
            if not self.cancelled():
                self.cancel()

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(
            response_bytes, self._response_deserializer
        )

    async def wait_for_connection(self) -> None:
        await self._invocation_task
        if self.done():
            await self._raise_for_status()
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """

    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        # Response reads await this task before receiving messages.
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request()
        )
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> None:
        """Serializes the request and starts the RPC via initiate_unary_stream.

        Raises:
          asyncio.CancelledError: re-raised after cancelling the RPC.
        """
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()
# pylint: disable=too-many-ancestors
class StreamUnaryCall(
    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        cython_call = channel.call(method, deadline, credentials, wait_for_ready)
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )

        self._context = cygrpc.build_census_context()
        self._init_stream_request_mixin(request_iterator)
        response_task = loop.create_task(self._conduct_rpc())
        self._init_unary_response_mixin(response_task)

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the RPC; returns the response, or cygrpc.EOF on failure.

        Raises:
          asyncio.CancelledError: re-raised after cancelling the RPC.
        """
        try:
            response_bytes = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(
            response_bytes, self._response_deserializer
        )
715class StreamStreamCall(
716 _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
717):
718 """Object for managing stream-stream RPC calls.
720 Returned when an instance of `StreamStreamMultiCallable` object is called.
721 """
723 _initializer: asyncio.Task
725 # pylint: disable=too-many-arguments
726 def __init__(
727 self,
728 request_iterator: Optional[RequestIterableType],
729 deadline: Optional[float],
730 metadata: Metadata,
731 credentials: Optional[grpc.CallCredentials],
732 wait_for_ready: Optional[bool],
733 channel: cygrpc.AioChannel,
734 method: bytes,
735 request_serializer: SerializingFunction,
736 response_deserializer: DeserializingFunction,
737 loop: asyncio.AbstractEventLoop,
738 ) -> None:
739 super().__init__(
740 channel.call(method, deadline, credentials, wait_for_ready),
741 metadata,
742 request_serializer,
743 response_deserializer,
744 loop,
745 )
746 self._context = cygrpc.build_census_context()
747 self._initializer = self._loop.create_task(self._prepare_rpc())
748 self._init_stream_request_mixin(request_iterator)
749 self._init_stream_response_mixin(self._initializer)
751 async def _prepare_rpc(self):
752 """This method prepares the RPC for receiving/sending messages.
754 All other operations around the stream should only happen after the
755 completion of this method.
756 """
757 try:
758 await self._cython_call.initiate_stream_stream(
759 self._metadata, self._metadata_sent_observer, self._context
760 )
761 except asyncio.CancelledError:
762 if not self.cancelled():
763 self.cancel()
764 # No need to raise RpcError here, because no one will `await` this task.