Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/grpc/aio/_interceptor.py: 38%
394 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:45 +0000
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from abc import ABCMeta
16from abc import abstractmethod
17import asyncio
18import collections
19import functools
20from typing import (
21 AsyncIterable,
22 Awaitable,
23 Callable,
24 Iterator,
25 List,
26 Optional,
27 Sequence,
28 Union,
29)
31import grpc
32from grpc._cython import cygrpc
34from . import _base_call
35from ._call import AioRpcError
36from ._call import StreamStreamCall
37from ._call import StreamUnaryCall
38from ._call import UnaryStreamCall
39from ._call import UnaryUnaryCall
40from ._call import _API_STYLE_ERROR
41from ._call import _RPC_ALREADY_FINISHED_DETAILS
42from ._call import _RPC_HALF_CLOSED_DETAILS
43from ._metadata import Metadata
44from ._typing import DeserializingFunction
45from ._typing import DoneCallbackType
46from ._typing import RequestIterableType
47from ._typing import RequestType
48from ._typing import ResponseIterableType
49from ._typing import ResponseType
50from ._typing import SerializingFunction
51from ._utils import _timeout_to_deadline
# Details string reported by an intercepted call whose interceptor task was
# cancelled locally before it produced a call (see ``InterceptedCall.details``).
_LOCAL_CANCELLATION_DETAILS: str = "Locally cancelled by application!"
class ServerInterceptor(metaclass=ABCMeta):
    """Hook for intercepting incoming RPCs on the server side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercept an incoming RPC before it is handed to a handler.

        Interceptors can pass state to downstream interceptors through
        contextvars. The first interceptor runs inside an empty
        contextvars.Context. That same Context is reused for every
        downstream interceptor and for the final handler call. There is no
        guarantee that interceptors and handlers run on the same thread.

        Args:
          continuation: A function that takes a HandlerCallDetails and
            moves on to the next interceptor in the chain, if there is one,
            or otherwise to the RPC handler lookup logic, passing the call
            details along. It returns an RpcMethodHandler instance when the
            RPC is considered serviced, or None otherwise.
          handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
          An RpcMethodHandler that can service the RPC if this interceptor
          chooses to service it, or None otherwise.
        """
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        "method timeout metadata credentials wait_for_ready",
    ),
    grpc.ClientCallDetails,
):
    """Immutable description of an RPC that is about to be invoked.

    This is an EXPERIMENTAL API.

    Args:
      method: The method name of the RPC.
      timeout: An optional duration, in seconds, allowed for the RPC.
      metadata: Optional metadata transmitted to the service side of the
        RPC.
      credentials: An optional CallCredentials for the RPC.
      wait_for_ready: An optional flag that enables the
        :term:`wait_for_ready` mechanism.
    """

    # Annotations for type checkers only; the field storage itself comes
    # from the namedtuple base class.
    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
class ClientInterceptor(metaclass=ABCMeta):
    """Common base class shared by every Aio client interceptor class."""
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Hook for intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Asynchronously intercept a unary-unary invocation.

        Args:
          continuation: A coroutine that carries the invocation forward,
            either by running the next interceptor in the chain or by
            invoking the actual RPC on the underlying Channel. If the
            interceptor wants the RPC to proceed, it must call it, e.g.
            `call = await continuation(client_call_details, request)`.
            `continuation` returns the call to the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          An object with the RPC response.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Hook for intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Asynchronously intercept a unary-stream invocation.

        The interceptor may return either the call object or an
        asynchronous iterator. When it returns an asynchronous iterator,
        that iterator becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that carries the invocation forward,
            either by running the next interceptor in the chain or by
            invoking the actual RPC on the underlying Channel. If the
            interceptor wants the RPC to proceed, it must call it, e.g.
            `call = await continuation(client_call_details, request)`.
            `continuation` returns the call to the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        # `continuation` receives the request *iterator*, not a single
        # request, so its argument type is RequestIterableType.
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Inside the interceptor, use the call methods such as `write`, or
        even awaiting the call, with care. The caller may expect an
        untouched call, for example so that it can start writing messages
        to it itself.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        # `continuation` receives the request *iterator*, not a single
        # request, so its argument type is RequestIterableType.
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Inside the interceptor, use the call methods such as `write`, or
        even awaiting the call, with care. The caller may expect an
        untouched call, for example so that it can start writing messages
        to it itself.

        The function may return either the call object or an asynchronous
        iterator. When it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors may do some work before the RPC is invoked, with the
    ability to change the invocation parameters. They may also do some work
    after the RPC is invoked, with access to the wrapped `UnaryUnaryCall`.

    This class also handles early and late cancellations. Early: the RPC has
    not started yet because the interceptors still hold execution. Late: the
    RPC has finished but the interceptors still hold execution.

    Once the RPC is actually executed, every method is delegated to the
    intercepted call. That is the same call object returned to the
    interceptors.

    As the base class for all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    # Task running the interceptor chain; its result is the call produced by
    # the interceptors.
    _interceptors_task: asyncio.Task
    # Done callbacks registered before the interceptor task finished; the
    # list is appended to and later dispatched, hence `List`.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # Subclasses create the interceptor task before calling
        # ``super().__init__``. If that step failed, the attribute was never
        # set and there is nothing to cancel. Without this guard, Python
        # would report the AttributeError as "Exception ignored in __del__".
        if not hasattr(self, "_interceptors_task"):
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Dispatch the callbacks that were registered while the task ran.

        Installed as the done callback of the interceptor task. Every
        pending callback is invoked right away with this object when:
          - the produced call has already finished, or
          - the interceptors ended with AioRpcError, or
          - the interceptors were cancelled.
        Otherwise each callback is attached to the produced call.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        # NOTE(review): any other exception raised by an interceptor
        # propagates out of this done callback, and the pending callbacks
        # are then never invoked.
        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        # Report completion with this wrapper rather than the inner call.
        callback(self)

    def cancel(self) -> bool:
        """Cancel the RPC, or the interceptor task if it is still running.

        Returns:
          The result of the task/call cancellation, or False when the
          interceptors already ended with AioRpcError or were cancelled.
        """
        if not self._interceptors_task.done():
            # The intercepted call is not available yet, so fall back to
            # the generic asyncio task cancellation.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Return whether the RPC ended up cancelled.

        False while the interceptors are still running. True when they were
        cancelled or failed with a CANCELLED AioRpcError. Otherwise the
        intercepted call's own answer.
        """
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Return whether the RPC has finished, successfully or not."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register `callback` to be invoked with this object on completion.

        Callbacks added while the interceptors are still running are queued
        and dispatched once the interceptor task finishes.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Initial metadata of the RPC.

        Returns None if the interceptors were cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Trailing metadata of the RPC.

        Returns None if the interceptors were cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Final status code.

        Returns CANCELLED if the interceptors were cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Final status details.

        Returns a fixed message if the interceptors were cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Debug error string.

        Returns "" if the interceptors were cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Delegate to the intercepted call once the interceptors finish.

        Errors or cancellation raised by the interceptor task propagate.
        """
        call = await self._interceptors_task
        return await call.wait_for_connection()
class _InterceptedUnaryResponseMixin:
    def __await__(self):
        """Wait for the interceptor chain, then await the call it produced.

        The value produced by awaiting the intercepted call is returned.
        Exceptions from either stage propagate to the awaiting code.
        """
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
class _InterceptedStreamResponseMixin:
    """Response-streaming support for intercepted calls.

    Responses come from the call produced by the interceptor task. They can
    be consumed with ``async for`` or with ``read()``; both share a single
    lazily created async generator.
    """

    # Lazily created async generator over the intercepted call's responses.
    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The iterator is created lazily. Creating it eagerly would make
        # Asyncio emit a logging warning if it is never consumed.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> AsyncIterable[ResponseType]:
        """Yield every response of the call produced by the interceptors."""
        call = await self._interceptors_task
        async for response in call:
            yield response

    def _ensure_response_aiter(self) -> AsyncIterable[ResponseType]:
        """Return the shared response generator, creating it on first use."""
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        return self._ensure_response_aiter()

    async def read(self) -> ResponseType:
        """Read the next response message.

        Returns:
          The next response, or ``cygrpc.EOF`` once the stream is exhausted.
          The grpc.aio ``read()`` API documents EOF as its end-of-stream
          marker. Letting StopAsyncIteration escape would break callers
          that loop until EOF.
        """
        try:
            return await self._ensure_response_aiter().asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
class _InterceptedStreamRequestMixin:
    """``write()``/``done_writing()`` support for intercepted calls.

    When the caller supplies no request iterator, requests written through
    ``write()`` pass through a one-slot queue. An async generator drains
    that queue and serves as the request iterator of the intercepted call.
    """

    # Proxy generator fed by ``write()``; only set when no iterator is given.
    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    # One-slot queue backing the proxy generator. None means the caller
    # supplied its own request iterator and ``write()`` is not allowed.
    _write_to_iterator_queue: Optional[asyncio.Queue]
    # Task awaiting ``call.code()``. It lets queue writes be interrupted
    # when the call terminates.
    _status_code_task: Optional[asyncio.Task]

    # Marker pushed through the queue to end the proxy generator.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Return the request iterator the intercepted call should consume.

        If `request_iterator` is None, a proxy generator fed by ``write()``
        is created and returned. Otherwise `request_iterator` is returned
        unchanged and ``write()``/``done_writing()`` are disabled.
        """
        if request_iterator is None:
            # We provide our own request iterator, which proxies the future
            # writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            self._status_code_task = None
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yield queued writes until the finish sentinel is received."""
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: InterceptedCall
    ):
        # Put 'request' on the request iterator queue, but stop waiting as
        # soon as 'call' terminates (its status code becomes available), so
        # an abruptly ended call cannot leave the writer blocked forever.
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        try:
            await asyncio.wait(
                (put_task, self._status_code_task),
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # asyncio.wait never cancels the tasks it waits on. If the call
            # ended first, or this coroutine was cancelled, the put is still
            # pending. Cancel it instead of leaving an orphaned task behind.
            if not put_task.done():
                put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Send `request` through the proxy request iterator.

        Raises:
          UsageError: The caller supplied its own request iterator.
          asyncio.InvalidStateError: The RPC already finished or was
            half-closed.
        """
        # If no queue was created, requests are expected through an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created, requests are expected through an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        # NOTE(review): unlike write(), an AioRpcError raised by the
        # interceptors propagates here unchanged.
        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Used for running a `UnaryUnaryCall` wrapped by interceptors.

    Awaiting this object is proxied to the intercepted call only once the
    interceptor task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Run the interceptor chain and return the resulting call."""

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            if not interceptors:
                # End of the chain: perform the actual RPC.
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            current, remaining = interceptors[0], interceptors[1:]
            continuation = functools.partial(_run_interceptor, remaining)
            call_or_response = await current.intercept_unary_unary(
                continuation, client_call_details, request
            )
            if isinstance(call_or_response, _base_call.UnaryUnaryCall):
                return call_or_response
            # The interceptor short-circuited with a plain response value.
            return UnaryUnaryCallResponse(call_or_response)

        return await _run_interceptor(
            list(interceptors),
            ClientCallDetails(
                method, timeout, metadata, credentials, wait_for_ready
            ),
            request,
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced along the interceptor chain. It backs the
    # wrapper built when an interceptor returns a bare response iterator.
    # (This was previously written with `=`, which assigned the typing
    # object as a class attribute instead of declaring an annotation.)
    _last_returned_call_from_interceptors: Optional[
        _base_call.UnaryStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryStreamCall:
        """Run the RPC call wrapped in interceptors."""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare async iterator. Wrap it
                    # around the last call produced deeper in the chain.
                    # NOTE(review): if the interceptor never invoked
                    # `continuation`, that call is still None here.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: perform the actual RPC.
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    Awaiting this object is proxied to the intercepted call only once the
    interceptor task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        # Substitutes a write()-driven proxy iterator when none is given.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> StreamUnaryCall:
        """Run the RPC call wrapped in interceptors."""

        async def _run_interceptor(
            # A list: it is truth-tested and sliced below (it was previously
            # mis-annotated as an Iterator).
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                # End of the chain: perform the actual RPC.
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
921class InterceptedStreamStreamCall(
922 _InterceptedStreamResponseMixin,
923 _InterceptedStreamRequestMixin,
924 InterceptedCall,
925 _base_call.StreamStreamCall,
926):
927 """Used for running a `StreamStreamCall` wrapped by interceptors."""
929 _loop: asyncio.AbstractEventLoop
930 _channel: cygrpc.AioChannel
931 _last_returned_call_from_interceptors = Optional[
932 _base_call.StreamStreamCall
933 ]
935 # pylint: disable=too-many-arguments
936 def __init__(
937 self,
938 interceptors: Sequence[StreamStreamClientInterceptor],
939 request_iterator: Optional[RequestIterableType],
940 timeout: Optional[float],
941 metadata: Metadata,
942 credentials: Optional[grpc.CallCredentials],
943 wait_for_ready: Optional[bool],
944 channel: cygrpc.AioChannel,
945 method: bytes,
946 request_serializer: SerializingFunction,
947 response_deserializer: DeserializingFunction,
948 loop: asyncio.AbstractEventLoop,
949 ) -> None:
950 self._loop = loop
951 self._channel = channel
952 self._init_stream_response_mixin()
953 request_iterator = self._init_stream_request_mixin(request_iterator)
954 self._last_returned_call_from_interceptors = None
955 interceptors_task = loop.create_task(
956 self._invoke(
957 interceptors,
958 method,
959 timeout,
960 metadata,
961 credentials,
962 wait_for_ready,
963 request_iterator,
964 request_serializer,
965 response_deserializer,
966 )
967 )
968 super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(
    self,
    interceptors: Sequence[StreamStreamClientInterceptor],
    method: bytes,
    timeout: Optional[float],
    metadata: Optional[Metadata],
    credentials: Optional[grpc.CallCredentials],
    wait_for_ready: Optional[bool],
    request_iterator: RequestIterableType,
    request_serializer: SerializingFunction,
    response_deserializer: DeserializingFunction,
) -> StreamStreamCall:
    """Run the stream-stream RPC through the interceptor chain.

    A ``ClientCallDetails`` is built from ``method``, ``timeout``,
    ``metadata``, ``credentials`` and ``wait_for_ready``. Interceptors are
    then invoked in order. Each interceptor receives a continuation that
    runs the remaining interceptors. Once the chain is exhausted, a real
    ``StreamStreamCall`` is created on ``self._channel``.

    At every level the resulting call is stored in
    ``self._last_returned_call_from_interceptors``.

    NOTE(review): the return annotation says ``StreamStreamCall``, but the
    value may be any ``_base_call.StreamStreamCall``. That includes a
    ``StreamStreamCallResponseIterator`` wrapping an interceptor's plain
    response iterator.
    """

    async def _run_interceptor(
        interceptors: List[StreamStreamClientInterceptor],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> _base_call.StreamStreamCall:
        if not interceptors:
            # End of the chain: issue the actual RPC using the (possibly
            # interceptor-modified) call details.
            self._last_returned_call_from_interceptors = StreamStreamCall(
                request_iterator,
                _timeout_to_deadline(client_call_details.timeout),
                client_call_details.metadata,
                client_call_details.credentials,
                client_call_details.wait_for_ready,
                self._channel,
                client_call_details.method,
                request_serializer,
                response_deserializer,
                self._loop,
            )
            return self._last_returned_call_from_interceptors

        # The continuation handed to the current interceptor runs the rest
        # of the chain. The parameter names of `_run_interceptor` are kept
        # because interceptors may pass them by keyword.
        remaining = interceptors[1:]
        current = interceptors[0]
        continuation = functools.partial(_run_interceptor, remaining)

        outcome = await current.intercept_stream_stream(
            continuation, client_call_details, request_iterator
        )

        if not isinstance(outcome, _base_call.StreamStreamCall):
            # The interceptor returned a bare response iterator. Pair it
            # with the most recently recorded call (read before it is
            # overwritten below) so call-level methods still work.
            outcome = StreamStreamCallResponseIterator(
                self._last_returned_call_from_interceptors,
                outcome,
            )

        self._last_returned_call_from_interceptors = outcome
        return outcome

    call_details = ClientCallDetails(
        method, timeout, metadata, credentials, wait_for_ready
    )
    return await _run_interceptor(
        list(interceptors), call_details, request_iterator
    )
def time_remaining(self) -> Optional[float]:
    """Report the remaining time until the deadline (unsupported here).

    Raises:
      NotImplementedError: always; this call type does not implement it.
    """
    raise NotImplementedError
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """A UnaryUnaryCall that has already finished with a known response.

    The call always reports itself as done and never cancelled, with
    status ``OK``, empty details, and no metadata or debug string.
    Awaiting it yields the ``response`` given to the constructor.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    def cancel(self) -> bool:
        """Return False: the call has already finished."""
        return False

    def cancelled(self) -> bool:
        """Return False: the call finished without being cancelled."""
        return False

    def done(self) -> bool:
        """Return True: the call is always complete."""
        return True

    def add_done_callback(self, unused_callback) -> None:
        """Unsupported for this call type; raises NotImplementedError."""
        raise NotImplementedError()

    def time_remaining(self) -> Optional[float]:
        """Unsupported for this call type; raises NotImplementedError."""
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Return None; no initial metadata is recorded."""
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Return None; no trailing metadata is recorded."""
        return None

    async def code(self) -> grpc.StatusCode:
        """Return ``grpc.StatusCode.OK``."""
        return grpc.StatusCode.OK

    async def details(self) -> str:
        """Return an empty details string."""
        return ""

    async def debug_error_string(self) -> Optional[str]:
        """Return None; there is no error to describe."""
        return None

    def __await__(self):
        # The trailing bare ``yield`` is never executed. Its presence only
        # makes this method a generator, so ``await call`` finishes at once
        # and evaluates to the stored response.
        return self._response
        yield  # pylint: disable=unreachable

    async def wait_for_connection(self) -> None:
        """No-op: there is no connection to wait for."""
class _StreamCallResponseIterator:
    """Pair an existing streaming call with a substitute response iterator.

    Call-level operations (cancellation, status, metadata, callbacks,
    connection waiting) are forwarded to the wrapped ``call``. Iterating
    with ``async for`` instead walks ``response_iterator``.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    def cancel(self) -> bool:
        """Forward to the wrapped call's ``cancel``."""
        return self._call.cancel()

    def cancelled(self) -> bool:
        """Forward to the wrapped call's ``cancelled``."""
        return self._call.cancelled()

    def done(self) -> bool:
        """Forward to the wrapped call's ``done``."""
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        """Register ``callback`` on the wrapped call."""
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        """Forward to the wrapped call's ``time_remaining``."""
        return self._call.time_remaining()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's initial metadata."""
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's trailing metadata."""
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Await and return the wrapped call's status code."""
        return await self._call.code()

    async def details(self) -> str:
        """Await and return the wrapped call's status details."""
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Await and return the wrapped call's debug error string."""
        return await self._call.debug_error_string()

    def __aiter__(self):
        # Responses come from the substitute iterator, not from the call.
        return self._response_iterator.__aiter__()

    async def wait_for_connection(self) -> None:
        """Await the wrapped call's ``wait_for_connection``."""
        return await self._call.wait_for_connection()
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall whose responses come from an alternative iterator."""

    async def read(self) -> ResponseType:
        """Unsupported: responses are consumed only through ``async for``.

        Raises:
          NotImplementedError: always.
        """
        # All response traffic goes through the async iterator, so this
        # method is not expected to be reached.
        raise NotImplementedError()
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall whose responses come from an alternative iterator.

    All traffic goes through async iterators. Responses come through this
    object's iterator, and requests through the async iterator provided by
    the InterceptedStreamStreamCall. The explicit ``read``, ``write`` and
    ``done_writing`` paths are therefore not expected to be used, and each
    raises ``NotImplementedError``.
    """

    async def read(self) -> ResponseType:
        """Unsupported: responses are consumed only through ``async for``."""
        raise NotImplementedError()

    async def write(self, request: RequestType) -> None:
        """Unsupported: requests are sent via the request async iterator."""
        raise NotImplementedError()

    async def done_writing(self) -> None:
        """Unsupported: the request iterator signals the end of writes."""
        raise NotImplementedError()

    @property
    def _done_writing_flag(self) -> bool:
        # Reflect the wrapped call's flag rather than tracking our own.
        return self._call._done_writing_flag