1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from abc import ABCMeta
16from abc import abstractmethod
17import asyncio
18import collections
19import functools
20from typing import (
21 AsyncIterable,
22 Awaitable,
23 Callable,
24 Iterator,
25 List,
26 Optional,
27 Sequence,
28 Union,
29)
30
31import grpc
32from grpc._cython import cygrpc
33
34from . import _base_call
35from ._call import AioRpcError
36from ._call import StreamStreamCall
37from ._call import StreamUnaryCall
38from ._call import UnaryStreamCall
39from ._call import UnaryUnaryCall
40from ._call import _API_STYLE_ERROR
41from ._call import _RPC_ALREADY_FINISHED_DETAILS
42from ._call import _RPC_HALF_CLOSED_DETAILS
43from ._metadata import Metadata
44from ._typing import DeserializingFunction
45from ._typing import DoneCallbackType
46from ._typing import RequestIterableType
47from ._typing import RequestType
48from ._typing import ResponseIterableType
49from ._typing import ResponseType
50from ._typing import SerializingFunction
51from ._utils import _timeout_to_deadline
52
53_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
54
55
class ServerInterceptor(metaclass=ABCMeta):
    """Interface for intercepting RPCs received on the service side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercept an incoming RPC before it is handed to a handler.

        Interceptors may pass state to the interceptors that follow them by
        using contextvars. The first interceptor is invoked from an empty
        contextvars.Context, and that same Context is reused for every
        downstream interceptor as well as for the final handler call. No
        guarantee is made that interceptors and handlers run on the same
        thread.

        Args:
          continuation: A function that accepts a HandlerCallDetails and
            proceeds to the next interceptor in the chain, if any, or to the
            RPC handler lookup logic, passing the call details along. It
            returns an RpcMethodHandler when the RPC is considered serviced,
            or None otherwise.
          handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
          An RpcMethodHandler that can service the RPC if this interceptor
          chooses to service it, or None otherwise.
        """
91
92
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        "method timeout metadata credentials wait_for_ready",
    ),
    grpc.ClientCallDetails,
):
    """Immutable description of an RPC that is about to be invoked.

    This is an EXPERIMENTAL API.

    Args:
      method: Name of the RPC method.
      timeout: Optional duration, in seconds, that the RPC is allowed to take.
      metadata: Optional metadata sent to the service side of the RPC.
      credentials: Optional CallCredentials to use for the RPC.
      wait_for_ready: Optional flag that enables the :term:`wait_for_ready`
        mechanism.
    """

    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
118
119
class ClientInterceptor(metaclass=ABCMeta):
    """Common base class of every asyncio client interceptor.

    It declares no methods of its own. The arity-specific interceptor
    interfaces (unary-unary, unary-stream, stream-unary, stream-stream)
    derive from it.
    """
122
123
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Interface for intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Asynchronously intercept a unary-unary invocation.

        Args:
          continuation: A coroutine that moves the invocation forward, either
            by running the next interceptor in the chain or by invoking the
            actual RPC on the underlying Channel. The interceptor is
            responsible for calling it if it decides to let the RPC proceed,
            e.g. `call = await continuation(client_call_details, request)`.
            `continuation` returns the call of the RPC.
          client_call_details: A ClientCallDetails describing the outgoing
            RPC.
          request: The request value of the RPC.

        Returns:
          An object carrying the RPC response.

        Raises:
          AioRpcError: The RPC terminated with a non-OK status.
          asyncio.CancelledError: The RPC was cancelled.
        """
158
159
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Interface for intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Asynchronously intercept a unary-stream invocation.

        The implementation may return either the call object or an
        asynchronous iterator. If it returns an asynchronous iterator, that
        iterator becomes the source of the reads performed by the caller.

        Args:
          continuation: A coroutine that moves the invocation forward, either
            by running the next interceptor in the chain or by invoking the
            actual RPC on the underlying Channel. The interceptor is
            responsible for calling it if it decides to let the RPC proceed,
            e.g. `call = await continuation(client_call_details, request)`.
            `continuation` returns the call of the RPC.
          client_call_details: A ClientCallDetails describing the outgoing
            RPC.
          request: The request value of the RPC.

        Returns:
          The RPC call, or an asynchronous iterator of responses.

        Raises:
          AioRpcError: The RPC terminated with a non-OK status.
          asyncio.CancelledError: The RPC was cancelled.
        """
198
199
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        # The continuation is invoked with the request *iterator*, not a
        # single request (see `_run_interceptor` in InterceptedStreamUnaryCall).
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, calling methods of the call such as `write`,
        or even awaiting the call, must be done carefully: the caller may
        expect to receive an untouched call, for example so that it can
        start writing messages to it.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
240
241
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        # The continuation is invoked with the request *iterator*, not a
        # single request (see `_run_interceptor` in
        # InterceptedStreamStreamCall).
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, calling methods of the call such as `write`,
        or even awaiting the call, must be done carefully: the caller may
        expect to receive an untouched call, for example so that it can
        start writing messages to it.

        The function may return either the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
286
287
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the capacity to change the invocation parameters, and some work to do
    after the RPC invocation, with the capacity to access the wrapped call.

    It also handles early and late cancellations: when the RPC has not even
    started and execution is still held by the interceptors, or when the RPC
    has finished but execution is again still held by the interceptors.

    Once the RPC is finally executed, all methods are delegated to the
    intercepted call, which is the same call returned to the interceptors.

    As the base class of all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    # Task running the interceptor chain; its result is the underlying call.
    _interceptors_task: asyncio.Task
    # Done-callbacks registered before `_interceptors_task` finished. This
    # list is appended to, hence `List` rather than the read-only `Sequence`.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # A subclass __init__ may fail before InterceptedCall.__init__ runs
        # (e.g. while creating the interceptor task). Such a partially built
        # object has no task to cancel, so do not raise AttributeError here.
        if getattr(self, "_interceptors_task", None) is None:
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Dispatch callbacks that were registered while the chain ran.

        Installed as a done-callback of the interceptor task. If the chain
        ended with AioRpcError/CancelledError, or the produced call is
        already done, each pending callback is invoked right away with this
        object. Otherwise they are attached to the produced call, wrapped so
        that they still receive this object rather than the inner call.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        # Present the intercepted call (self), not the inner call, to users.
        callback(self)

    def cancel(self) -> bool:
        """Cancel the RPC.

        Returns:
          The result of cancelling the interceptor task if it is still
          running, False if the chain already ended with AioRpcError or
          cancellation, otherwise the result of cancelling the produced call.
        """
        if not self._interceptors_task.done():
            # The intercepted call is not available yet; try to cancel it
            # with the generic asyncio cancellation method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Report whether the RPC was cancelled.

        False while the interceptor chain is still running; True if the chain
        was cancelled or failed with an AioRpcError whose code is CANCELLED;
        otherwise delegated to the produced call.
        """
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Report whether the RPC has finished.

        False while the interceptor chain is still running; True if the chain
        ended with AioRpcError or cancellation; otherwise delegated to the
        produced call.
        """
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register `callback(self)` to run once the RPC is done.

        While the interceptor chain is running the callback is queued and
        dispatched by `_fire_or_add_pending_done_callbacks`. Afterwards it is
        invoked immediately if the RPC is already finished, or attached to
        the produced call otherwise.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Initial metadata of the RPC; None if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Trailing metadata of the RPC; None if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Status code of the RPC; CANCELLED if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Status details of the RPC.

        Returns _LOCAL_CANCELLATION_DETAILS if the chain was cancelled.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Debug error string of the RPC; "" if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Wait for the chain, then for the produced call's connection.

        Unlike the status accessors above, any exception raised by the
        interceptor task propagates to the caller.
        """
        call = await self._interceptors_task
        return await call.wait_for_connection()
466
467
468class _InterceptedUnaryResponseMixin:
469 def __await__(self):
470 call = yield from self._interceptors_task.__await__()
471 response = yield from call.__await__()
472 return response
473
474
class _InterceptedStreamResponseMixin:
    """Response-streaming support (`__aiter__`/`read`) for intercepted calls.

    Responses are pulled from the call produced by the interceptor task,
    through a single lazily created async generator shared by `__aiter__`
    and `read`.
    """

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # Initialized later, otherwise if the iterator is not finally
        # consumed a logging warning is emitted by asyncio.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> AsyncIterable[ResponseType]:
        """Async generator: wait for the produced call, then relay its responses."""
        call = await self._interceptors_task
        async for response in call:
            yield response

    def _get_response_aiter(self) -> AsyncIterable[ResponseType]:
        # Create the shared response generator on first use only.
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        return self._get_response_aiter()

    async def read(self) -> ResponseType:
        """Read the next response message.

        Returns:
          The next response, or `cygrpc.EOF` once the stream is exhausted
          (instead of leaking the generator's StopAsyncIteration to the
          caller). Errors raised while iterating the underlying call, or by
          the interceptor task, propagate unchanged.
        """
        try:
            return await self._get_response_aiter().asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
503
504
class _InterceptedStreamRequestMixin:
    """Request-streaming support (`write`/`done_writing`) for intercepted calls.

    When the caller supplies no request iterator, values passed to `write`
    go through a single-slot asyncio.Queue into an async generator. That
    generator is used as the request iterator of the underlying call.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    # Queued by `done_writing` to terminate the proxy request iterator.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Select the request iterator for the call and set up write state.

        Args:
          request_iterator: The caller-supplied iterator, or None when the
            caller intends to use `write`/`done_writing`.

        Returns:
          `request_iterator` unchanged, or, when it is None, the proxy async
          generator fed by `write`.
        """
        # Initialized in both branches so the attribute always exists.
        self._status_code_task = None
        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            request_iterator = self._write_to_iterator_async_gen
        else:
            # Caller-driven requests: `write`/`done_writing` raise UsageError.
            self._write_to_iterator_queue = None
            self._write_to_iterator_async_gen = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yield queued writes until the finish sentinel is dequeued."""
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: _base_call.Call
    ):
        """Queue `request`, giving up if `call` terminates first.

        Waits until either the queue accepts the value or `call.code()`
        resolves, whichever happens first.

        Args:
          request: The value to enqueue (a request or the finish sentinel).
          call: The call produced by the interceptor task; its `code()` is
            awaited (once, cached in `_status_code_task`) to detect
            termination.
        """
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        await asyncio.wait(
            (put_task, self._status_code_task),
            return_when=asyncio.FIRST_COMPLETED,
        )
        if not put_task.done():
            # The call terminated before the queue accepted the value. Cancel
            # the put instead of leaving it pending indefinitely.
            put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Send `request` through the proxy request iterator.

        Raises:
          cygrpc.UsageError: The caller supplied its own request iterator.
          asyncio.InvalidStateError: The RPC already finished (including the
            interceptor chain failing or being cancelled), or writing was
            already marked as done on the call.
        """
        # If no queue was created, requests are expected to come from an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent. It queues the finish sentinel, which ends
        the proxy request iterator.

        Raises:
          cygrpc.UsageError: The caller supplied its own request iterator.
          asyncio.InvalidStateError: The interceptor task was cancelled.
        """
        # If no queue was created, requests are expected to come from an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
601
602
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` through a chain of client interceptors.

    Awaiting this object is forwarded to the intercepted call, but only once
    the interceptor task has finished producing it.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            if not interceptors:
                # End of the chain: issue the real RPC on the channel.
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            # Each interceptor continues into the remainder of the chain.
            next_step = functools.partial(_run_interceptor, interceptors[1:])
            outcome = await interceptors[0].intercept_unary_unary(
                next_step, client_call_details, request
            )

            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            # The interceptor returned a bare response; wrap it so a call
            # object is still handed back.
            return UnaryUnaryCallResponse(outcome)

        details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(list(interceptors), details, request)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
703
704
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Call most recently produced while running the interceptor chain. This is
    # a type annotation; the value is assigned per instance in __init__ and
    # _invoke, so the typing object itself is not stored as a class attribute.
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.UnaryStreamCall:
        """Run the RPC call wrapped in interceptors.

        Every level records the call it returns in
        `self._last_returned_call_from_interceptors`. When an interceptor
        returns a plain response iterator, that iterator is wrapped together
        with the call recorded by the inner levels. That value is None if the
        interceptor never invoked its continuation.
        """

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # Reads come from the interceptor's iterator; the other
                    # call methods are served by the previously recorded call.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: issue the real RPC on the channel.
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
818
819
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    The `__await__` method is proxied to the intercepted call only once the
    interceptor task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        # None means the caller will use write()/done_writing(); the mixin
        # then substitutes its proxy iterator.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            # A list (indexed and sliced below), like the sibling classes;
            # an Iterator would support neither operation.
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                # End of the chain: issue the real RPC on the channel.
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
919
920
class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Call most recently produced while running the interceptor chain. This is
    # a type annotation; the value is assigned per instance in __init__ and
    # _invoke, so the typing object itself is not stored as a class attribute.
    _last_returned_call_from_interceptors: Optional[
        _base_call.StreamStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        # None means the caller will use write()/done_writing(); the mixin
        # then substitutes its proxy iterator.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamStreamCall:
        """Run the RPC call wrapped in interceptors.

        Every level records the call it returns in
        `self._last_returned_call_from_interceptors`. When an interceptor
        returns a plain response iterator, that iterator is wrapped together
        with the call recorded by the inner levels. That value is None if the
        interceptor never invoked its continuation.
        """

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # Reads come from the interceptor's iterator; the other
                    # call methods are served by the previously recorded call.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: issue the real RPC on the channel.
                self._last_returned_call_from_interceptors = StreamStreamCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )
                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
1039
1040
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """Final UnaryUnaryCall that has already finished with a response.

    Every status accessor reports a completed RPC with status ``OK``, no
    metadata and empty details. Awaiting the object produces the stored
    response immediately, without suspending.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    # Lifecycle: this call is always already finished.

    def cancel(self) -> bool:
        """Return ``False``; a finished call cannot be cancelled."""
        return False

    def cancelled(self) -> bool:
        """Return ``False``; this call was never cancelled."""
        return False

    def done(self) -> bool:
        """Return ``True``; this call is always finished."""
        return True

    def add_done_callback(self, unused_callback) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def time_remaining(self) -> Optional[float]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    # Status accessors: fixed values describing a successful RPC.

    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ""

    async def debug_error_string(self) -> Optional[str]:
        return None

    def __await__(self):
        # Delegating to an empty iterable turns this into a generator-based
        # awaitable that never suspends. The return value becomes the result
        # of `await`.
        yield from ()
        return self._response

    async def wait_for_connection(self) -> None:
        return None
1088
1089
class _StreamCallResponseIterator:
    """Shared base for streaming calls that use a substitute response iterator.

    Lifecycle, status and metadata queries are forwarded to the wrapped
    ``call``. Iterating with ``async for`` consumes ``response_iterator``
    instead of the call itself.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    # Lifecycle queries, delegated to the wrapped call.

    def cancel(self) -> bool:
        return self._call.cancel()

    def cancelled(self) -> bool:
        return self._call.cancelled()

    def done(self) -> bool:
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    # Status and metadata, awaited from the wrapped call.

    async def initial_metadata(self) -> Optional[Metadata]:
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        return await self._call.code()

    async def details(self) -> str:
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        return await self._call.debug_error_string()

    def __aiter__(self):
        # Responses come from the substitute iterator, not from the call.
        return self._response_iterator.__aiter__()

    async def wait_for_connection(self) -> None:
        return await self._call.wait_for_connection()
1137
1138
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        # Responses are only consumed through the async iterator, so this
        # method is not expected to be reached.
        raise NotImplementedError
1148
1149
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        # Responses are only consumed through the async iterator, so this
        # method is not expected to be reached.
        raise NotImplementedError

    async def write(self, request: RequestType) -> None:
        # Requests flow through the async iterator provided by
        # InterceptedStreamStreamCall, so this method is not expected to be
        # reached.
        raise NotImplementedError

    async def done_writing(self) -> None:
        # Requests flow through the async iterator provided by
        # InterceptedStreamStreamCall, so this method is not expected to be
        # reached.
        raise NotImplementedError

    @property
    def _done_writing_flag(self) -> bool:
        # Mirrors the wrapped call's flag.
        return self._call._done_writing_flag