1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from abc import ABCMeta
16from abc import abstractmethod
17import asyncio
18import collections
19import functools
20from typing import (
21 AsyncIterable,
22 Awaitable,
23 Callable,
24 Iterator,
25 List,
26 Optional,
27 Sequence,
28 Union,
29)
30
31import grpc
32from grpc._cython import cygrpc
33
34from . import _base_call
35from ._call import AioRpcError
36from ._call import StreamStreamCall
37from ._call import StreamUnaryCall
38from ._call import UnaryStreamCall
39from ._call import UnaryUnaryCall
40from ._call import _API_STYLE_ERROR
41from ._call import _RPC_ALREADY_FINISHED_DETAILS
42from ._call import _RPC_HALF_CLOSED_DETAILS
43from ._metadata import Metadata
44from ._typing import DeserializingFunction
45from ._typing import DoneCallbackType
46from ._typing import EOFType
47from ._typing import RequestIterableType
48from ._typing import RequestType
49from ._typing import ResponseIterableType
50from ._typing import ResponseType
51from ._typing import SerializingFunction
52from ._utils import _timeout_to_deadline
53
# Details string reported by InterceptedCall.details() when the interceptors
# task was cancelled before it produced an underlying call.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
55
56
class ServerInterceptor(metaclass=ABCMeta):
    """Affords intercepting incoming RPCs on the service-side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails],
            Awaitable[Optional[grpc.RpcMethodHandler]],
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        """Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        via contextvars. The first interceptor is called from an empty
        contextvars.Context, and the same Context is used for downstream
        interceptors and for the final handler call. Note that there are no
        guarantees that interceptors and handlers will be called from the
        same thread.

        Args:
            continuation: A function that takes a HandlerCallDetails and
                proceeds to invoke the next interceptor in the chain, if any,
                or the RPC handler lookup logic, with the call details passed
                as an argument, and returns an RpcMethodHandler instance if
                the RPC is considered serviced, or None otherwise.
            handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
            An RpcMethodHandler with which the RPC may be serviced if the
            interceptor chooses to service this RPC, or None otherwise.
        """
92
93
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        "method timeout metadata credentials wait_for_ready",
    ),
    grpc.ClientCallDetails,
):
    """Immutable description of an outgoing RPC, as seen by interceptors.

    This is an EXPERIMENTAL API.

    Args:
        method: Name of the RPC method being invoked.
        timeout: Optional time budget for the RPC, in seconds.
        metadata: Optional metadata to send to the service side of the RPC.
        credentials: Optional CallCredentials to attach to the RPC.
        wait_for_ready: Optional flag enabling the :term:`wait_for_ready`
            mechanism.
    """

    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
119
120
class ClientInterceptor(metaclass=ABCMeta):
    """Common root shared by every asyncio client-side interceptor class."""
123
124
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType],
            Awaitable[_base_call.UnaryUnaryCall],
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercepts a unary-unary invocation asynchronously.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. Awaiting `continuation` yields the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            An object with the RPC response: either the call itself or the
            response value.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
159
160
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType],
            Awaitable[_base_call.UnaryStreamCall],
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercepts a unary-stream invocation asynchronously.

        The function may return either the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. Awaiting `continuation` yields the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
199
200
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType],
            Awaitable[_base_call.StreamUnaryCall],
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, using call methods like `write`, or even
        awaiting the call, should be done carefully: the caller could be
        expecting an untouched call, for example to start writing messages
        to it.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. Awaiting `continuation` yields the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
241
242
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType],
            Awaitable[_base_call.StreamStreamCall],
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, using call methods like `write`, or even
        awaiting the call, should be done carefully: the caller could be
        expecting an untouched call, for example to start writing messages
        to it.

        The function may return either the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. Awaiting `continuation` yields the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
287
288
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the capacity of changing the invocation parameters, and some work to do
    after the RPC invocation, with the capacity of accessing the wrapped
    call.

    It also handles early and late cancellations: when the RPC has not even
    started and execution is still held by the interceptors, or when the RPC
    has finished but execution is still held by the interceptors.

    Once the RPC is finally executed, all methods are delegated to the
    intercepted call, which is also the call returned to the interceptors.

    As the base class for all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    # Task running the interceptor chain; its result is the underlying call.
    _interceptors_task: asyncio.Task
    # Callbacks registered before the interceptors task finished. The list is
    # appended to by add_done_callback(), hence List rather than Sequence.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # Subclass constructors do work (e.g. creating the interceptors task)
        # before reaching __init__ above. If that work raised, the attribute
        # was never set and cancel() would fail with AttributeError inside
        # the finalizer, so there is nothing to cancel in that case.
        if not hasattr(self, "_interceptors_task"):
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Flushes callbacks registered before the interceptors finished.

        Runs as a done callback of the interceptors task. Pending callbacks
        are invoked immediately with this object when the interceptors ended
        with AioRpcError / cancellation or the resulting call is already
        done; otherwise they are forwarded to the resulting call.

        NOTE(review): any other exception raised by the interceptors
        propagates out of ``result()`` here, so the pending callbacks are
        never invoked in that case.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                # Report this intercepted call, not the inner call, to the
                # user's callback.
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        """Invokes ``callback`` with this intercepted call as its argument."""
        callback(self)

    def cancel(self) -> bool:
        """Cancels the RPC.

        While the interceptors are still running, the interceptors task is
        cancelled. Afterwards the cancellation is forwarded to the resulting
        call. Returns False when the interceptors already ended with an
        AioRpcError or were cancelled.
        """
        if not self._interceptors_task.done():
            # There is no intercepted call available yet; try to cancel it
            # using the generic asyncio cancellation method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Returns True if the interceptors or the resulting call were cancelled."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Returns True once the interceptors failed or the resulting call is done."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers ``callback`` to be called with this object when done.

        Callbacks added before the interceptors finish are queued and
        flushed by ``_fire_or_add_pending_done_callbacks``.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Initial metadata of the RPC; None if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Trailing metadata of the RPC; None if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Final status code; CANCELLED if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Status details; a local-cancellation message if interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Debug error string; empty if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Waits for the interceptors, then for the resulting call's connection.

        Exceptions raised by the interceptors task propagate unchanged.
        """
        call = await self._interceptors_task
        return await call.wait_for_connection()
467
468
class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response.

    Awaiting first waits for ``self._interceptors_task`` (supplied by the
    host class) to produce the underlying call, then awaits that call and
    hands back its result.
    """

    def __await__(self):
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
474
475
class _InterceptedStreamResponseMixin:
    """Proxies response reads to the call produced by the interceptors.

    The host class is expected to provide ``_interceptors_task``. Iteration
    (``async for``) and ``read()`` share a single underlying async generator.
    """

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The iterator is created lazily: creating it eagerly and never
        # consuming it would make asyncio emit a logging warning.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> AsyncIterable[ResponseType]:
        """Yields responses from the call produced by the interceptors."""
        call = await self._interceptors_task
        async for response in call:
            yield response

    def _get_response_aiter(self) -> AsyncIterable[ResponseType]:
        """Returns the shared response iterator, creating it on first use."""
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        return self._get_response_aiter()

    async def read(self) -> Union[EOFType, ResponseType]:
        """Returns the next response, or ``cygrpc.EOF`` once exhausted."""
        try:
            return await self._get_response_aiter().asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
507
508
class _InterceptedStreamRequestMixin:
    """Request-side helpers for intercepted client-streaming calls.

    When the caller does not supply a request iterator, requests handed to
    ``write()`` are pushed through a size-1 queue that backs an internal
    async generator; that generator is what gets passed down the interceptor
    chain. The host class is expected to provide ``_loop`` and
    ``_interceptors_task``.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    # Enqueued by done_writing() to terminate the proxy request iterator.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Chooses the request iterator to hand to the interceptors.

        Returns ``request_iterator`` unchanged when one was given. Otherwise
        it sets up the write queue and returns the proxy async generator fed
        by ``write()`` / ``done_writing()``.
        """
        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            self._status_code_task = None
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yields queued writes until the finish sentinel is dequeued."""
        # Do not start consuming the queue before the interceptors finished.
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: _base_call.Call
    ):
        """Enqueues ``request``, giving up early if ``call`` terminates.

        The specified ``call`` (the call produced by the interceptors) is
        used to interrupt the write if the call ends abruptly while the
        queue is full.
        """
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        await asyncio.wait(
            (put_task, self._status_code_task),
            return_when=asyncio.FIRST_COMPLETED,
        )
        if not put_task.done():
            # The call produced a status before the queue had room. Cancel
            # the pending put rather than leaving an orphaned task behind.
            # Presumably the finished call no longer drains the queue.
            put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Sends ``request`` through the proxy request iterator.

        Raises:
            cygrpc.UsageError: If the call was created with a request
                iterator instead of the write() API.
            asyncio.InvalidStateError: If the RPC has already finished
                (including the interceptors failing or being cancelled), or
                the underlying call is already half-closed.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        # The write may have been interrupted by the call finishing.
        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        # NOTE(review): unlike write(), an AioRpcError raised by the
        # interceptors propagates unchanged here; only cancellation is
        # mapped to InvalidStateError.
        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
605
606
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` wrapped by interceptors.

    Awaiting this object is proxied to the intercepted call only once the
    interceptors task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Runs the interceptor chain and returns the resulting call."""

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            if not interceptors:
                # End of the chain: issue the real RPC using whatever call
                # details the interceptors ended up with.
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            current, remaining = interceptors[0], interceptors[1:]
            next_step = functools.partial(_run_interceptor, remaining)
            outcome = await current.intercept_unary_unary(
                next_step, client_call_details, request
            )
            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            # The interceptor handed back a bare response value.
            return UnaryUnaryCallResponse(outcome)

        initial_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), initial_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
707
708
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while running the interceptor chain. It backs
    # the response iterator when an interceptor returns a bare iterator.
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.UnaryStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare iterator; pair it with
                    # the last call seen deeper in the chain.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
822
823
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    Awaiting this object is proxied to the intercepted call only once the
    interceptors task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        # With no caller-provided iterator, this becomes the write() proxy.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            # `interceptors` is sliced and truth-tested, so it must be a list.
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
923
924
class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while running the interceptor chain. It backs
    # the response iterator when an interceptor returns a bare iterator.
    _last_returned_call_from_interceptors: Optional[_base_call.StreamStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        # With no caller-provided iterator, this becomes the write() proxy.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare iterator; pair it with
                    # the last call seen deeper in the chain.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = StreamStreamCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )
                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
1043
1044
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """Final UnaryUnaryCall class finished with a response.

    The object stands for a call that has already completed. It cannot be
    cancelled and reports ``done()``. Its status accessors return fixed
    values: OK code, empty details, and no metadata or debug string.
    Awaiting it produces the stored response.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    # Lifecycle: the call is already complete, so it is never cancellable.
    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def add_done_callback(self, unused_callback) -> None:
        raise NotImplementedError

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError

    # Status accessors: constant values describing a successful call.
    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ""

    async def debug_error_string(self) -> Optional[str]:
        return None

    def __await__(self):
        # ``yield from ()`` yields nothing, but it makes this method a
        # generator, which is what the await protocol requires. The stored
        # response becomes the value of the ``await`` expression.
        yield from ()
        return self._response

    async def wait_for_connection(self) -> None:
        return None
1092
1093
class _StreamCallResponseIterator:
    """Call wrapper that forwards to ``_call`` but iterates another source.

    Cancellation, callbacks, deadlines, metadata, status and connection
    waiting are all delegated to the wrapped call object. Only asynchronous
    iteration (``__aiter__``) is served by the separately supplied
    ``response_iterator``.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    # --- synchronous call-state delegation -------------------------------

    def cancel(self) -> bool:
        """Forward ``cancel()`` to the wrapped call."""
        return self._call.cancel()

    def cancelled(self) -> bool:
        """Forward ``cancelled()`` to the wrapped call."""
        return self._call.cancelled()

    def done(self) -> bool:
        """Forward ``done()`` to the wrapped call."""
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        """Register ``callback`` on the wrapped call."""
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        """Forward ``time_remaining()`` to the wrapped call."""
        return self._call.time_remaining()

    # --- asynchronous status delegation ----------------------------------

    async def initial_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's initial metadata."""
        metadata = await self._call.initial_metadata()
        return metadata

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's trailing metadata."""
        metadata = await self._call.trailing_metadata()
        return metadata

    async def code(self) -> grpc.StatusCode:
        """Await and return the wrapped call's status code."""
        status_code = await self._call.code()
        return status_code

    async def details(self) -> str:
        """Await and return the wrapped call's status details."""
        status_details = await self._call.details()
        return status_details

    async def debug_error_string(self) -> Optional[str]:
        """Await and return the wrapped call's debug error string."""
        debug_string = await self._call.debug_error_string()
        return debug_string

    # --- iteration and connection ----------------------------------------

    def __aiter__(self):
        # Iteration comes from the alternative iterator, not from ``_call``.
        source = self._response_iterator
        return source.__aiter__()

    async def wait_for_connection(self) -> None:
        """Wait until the wrapped call is connected."""
        outcome = await self._call.wait_for_connection()
        return outcome
1141
1142
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator."""

    async def read(self) -> Union[EOFType, ResponseType]:
        """Unsupported on this wrapper.

        Responses are delivered through the async iterator (``__aiter__``),
        so this path is not expected to be reached.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError
1152
1153
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator.

    Reading and writing are not supported on this wrapper. Responses come
    from the async iterator, and requests are sent through the async
    iterator provided by the InterceptedStreamStreamCall.
    """

    async def read(self) -> Union[EOFType, ResponseType]:
        """Unsupported: responses are consumed via ``__aiter__``.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    async def write(self, request: RequestType) -> None:
        """Unsupported: requests flow through the request async iterator.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    async def done_writing(self) -> None:
        """Unsupported: requests flow through the request async iterator.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    @property
    def _done_writing_flag(self) -> bool:
        # Mirror the wrapped call's private flag.
        wrapped = self._call
        return wrapped._done_writing_flag