1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from __future__ import annotations
16
17from abc import ABCMeta
18from abc import abstractmethod
19import asyncio
20import collections
21import functools
22from typing import (
23 AsyncIterable,
24 Awaitable,
25 Callable,
26 List,
27 Optional,
28 Sequence,
29 Union,
30)
31
32import grpc
33from grpc._cython import cygrpc
34
35from . import _base_call
36from ._call import AioRpcError
37from ._call import StreamStreamCall
38from ._call import StreamUnaryCall
39from ._call import UnaryStreamCall
40from ._call import UnaryUnaryCall
41from ._call import _API_STYLE_ERROR
42from ._call import _RPC_ALREADY_FINISHED_DETAILS
43from ._call import _RPC_HALF_CLOSED_DETAILS
44from ._metadata import Metadata
45from ._typing import DeserializingFunction
46from ._typing import DoneCallbackType
47from ._typing import EOFType
48from ._typing import RequestIterableType
49from ._typing import RequestType
50from ._typing import ResponseIterableType
51from ._typing import ResponseType
52from ._typing import SerializingFunction
53from ._utils import _timeout_to_deadline
54
# Details string returned by InterceptedCall.details() when the interceptor
# chain was cancelled before it produced a call to delegate to.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
56
57
class ServerInterceptor(metaclass=ABCMeta):
    """Base class for intercepting incoming RPCs on the server side.

    Subclasses implement ``intercept_service``.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercept an incoming RPC before it is handed to a handler.

        Interceptors can pass state to the interceptors after them through
        contextvars. The first interceptor runs in an empty
        contextvars.Context, and that same Context is reused for every
        downstream interceptor and for the final handler call. There is no
        guarantee that interceptors and handlers run on the same thread.

        Args:
          continuation: A callable taking a HandlerCallDetails. It invokes
            the next interceptor in the chain, if any, or else the RPC
            handler lookup logic. It produces (through an awaitable) an
            RpcMethodHandler if the RPC is considered serviced, or None
            otherwise.
          handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
          An RpcMethodHandler that can service the RPC if this interceptor
          decides to service it, or None otherwise.
        """
93
94
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        "method timeout metadata credentials wait_for_ready",
    ),
    grpc.ClientCallDetails,
):
    """A namedtuple describing an RPC that is about to be invoked.

    This is an EXPERIMENTAL API.

    Attributes:
      method: The method name of the RPC.
      timeout: An optional duration of time in seconds to allow for the RPC.
      metadata: Optional metadata to be transmitted to the service-side of
        the RPC.
      credentials: An optional CallCredentials for the RPC.
      wait_for_ready: An optional flag to enable the :term:`wait_for_ready`
        mechanism.
    """

    method: bytes
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
120
121
class ClientInterceptor(metaclass=ABCMeta):
    """Common base class of every asyncio client interceptor type.

    The arity-specific interceptor classes in this module
    (unary-unary, unary-stream, stream-unary, stream-stream) derive from it.
    """
124
125
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Base class for interceptors of unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercept a unary-unary invocation asynchronously.

        Args:
          continuation: A coroutine that moves the RPC forward. It either runs
            the next interceptor in the chain or, at the end of the chain,
            invokes the actual RPC on the underlying Channel. The interceptor
            must call it if it wants the RPC to proceed, for example with
            `call = await continuation(client_call_details, request)`. The
            awaited result is the call for the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          Either a call object for the RPC or the RPC response itself.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
160
161
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Base class for interceptors of unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercept a unary-stream invocation asynchronously.

        The interceptor may return either the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that moves the RPC forward. It either runs
            the next interceptor in the chain or, at the end of the chain,
            invokes the actual RPC on the underlying Channel. The interceptor
            must call it if it wants the RPC to proceed, for example with
            `call = await continuation(client_call_details, request)`. The
            awaited result is the call for the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          The RPC call or an asynchronous iterator of responses.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
200
201
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Base class for interceptors of stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercept a stream-unary invocation asynchronously.

        Be careful when using the call's methods inside the interceptor,
        such as `write`, or when awaiting the call itself. The caller may
        expect to receive an untouched call, for example so that it can
        start writing messages to it.

        Args:
          continuation: A coroutine that moves the RPC forward. It either runs
            the next interceptor in the chain or, at the end of the chain,
            invokes the actual RPC on the underlying Channel. The interceptor
            must call it if it wants the RPC to proceed, for example with
            `call = await continuation(client_call_details, request_iterator)`.
            The awaited result is the call for the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
242
243
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Base class for interceptors of stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercept a stream-stream invocation asynchronously.

        Be careful when using the call's methods inside the interceptor,
        such as `write`, or when awaiting the call itself. The caller may
        expect to receive an untouched call, for example so that it can
        start writing messages to it.

        The interceptor may return either the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that moves the RPC forward. It either runs
            the next interceptor in the chain or, at the end of the chain,
            invokes the actual RPC on the underlying Channel. The interceptor
            must call it if it wants the RPC to proceed, for example with
            `call = await continuation(client_call_details, request_iterator)`.
            The awaited result is the call for the RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC call or an asynchronous iterator of responses.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
288
289
class InterceptedCall:
    """Base implementation shared by all intercepted call arities.

    The interceptor chain runs inside ``_interceptors_task``. The task's
    result is the call object produced by the chain. Until the task
    finishes there is no call to delegate to, so this class answers status,
    metadata, cancellation and done-callback queries itself:

    * While the task is still running, ``cancel()`` cancels the task, and
      ``done()``/``cancelled()`` report ``False``.
    * If the task ended with ``AioRpcError`` or ``asyncio.CancelledError``,
      the answers are derived from that outcome. For example, ``code()``
      returns the error's code, or ``StatusCode.CANCELLED``.
    * Otherwise every query is delegated to the call the task returned.
    """

    # Task running the interceptor chain. Its result is the call object the
    # chain produced.
    _interceptors_task: asyncio.Task
    # Callbacks registered through add_done_callback() while
    # _interceptors_task was still running. They are flushed by
    # _fire_or_add_pending_done_callbacks(). The list is appended to, hence
    # List rather than Sequence.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        """Track the interceptor task and hook the pending-callback flush.

        Args:
          interceptors_task: Task running the interceptor chain.
        """
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # When this object is garbage-collected, cancel the still-running
        # interceptor task, or the call it produced.
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Flush callbacks registered before the interceptor task finished.

        Runs as a done-callback of ``interceptors_task``. If the RPC is
        already over, each pending callback is invoked immediately with
        ``self``. The RPC counts as over if the resulting call is done, or if
        the chain failed with AioRpcError/CancelledError. Otherwise each
        callback is attached to the resulting call, wrapped so that it still
        receives ``self`` rather than the inner call.

        NOTE(review): any other exception raised by the interceptor chain is
        re-raised here by ``result()``, and the pending callbacks are then
        never invoked.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        # Adapter installed on the inner call: user callbacks receive this
        # intercepted call instead of the inner one.
        callback(self)

    def cancel(self) -> bool:
        """Cancel the interceptor task, or the call it produced.

        Returns:
          The result of ``Task.cancel()`` while the interceptors are still
          running; False if the chain already ended with AioRpcError or
          CancelledError; otherwise the result of the inner call's
          ``cancel()``.
        """
        if not self._interceptors_task.done():
            # The intercepted call is not available yet, so cancel the
            # interceptor task with the generic asyncio cancellation
            # method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Report whether the RPC was cancelled.

        False while the interceptors are still running. True if the chain
        was cancelled or failed with a CANCELLED AioRpcError. Otherwise the
        result is delegated to the inner call.
        """
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Report whether the RPC has finished.

        False while the interceptors are still running. True if the chain
        ended with AioRpcError or CancelledError. Otherwise the result is
        delegated to the inner call.
        """
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register ``callback`` to be invoked with this call once it is done.

        While the interceptors are running, the callback is queued and later
        flushed by ``_fire_or_add_pending_done_callbacks``. If the RPC is
        already over, it is invoked immediately. Otherwise it is attached
        to the inner call.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        # Not supported for intercepted calls.
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Initial metadata of the RPC.

        Taken from the AioRpcError if the chain failed with one, None if the
        chain was cancelled, otherwise delegated to the inner call.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Trailing metadata of the RPC.

        Taken from the AioRpcError if the chain failed with one, None if the
        chain was cancelled, otherwise delegated to the inner call.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Status code of the RPC.

        Taken from the AioRpcError if the chain failed with one,
        ``StatusCode.CANCELLED`` if the chain was cancelled, otherwise
        delegated to the inner call.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Status details of the RPC.

        Taken from the AioRpcError if the chain failed with one,
        ``_LOCAL_CANCELLATION_DETAILS`` if the chain was cancelled, otherwise
        delegated to the inner call.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Debug error string of the RPC.

        Taken from the AioRpcError if the chain failed with one, an empty
        string if the chain was cancelled, otherwise delegated to the inner
        call.
        """
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        # Unlike the status accessors above, failures of the interceptor
        # chain (AioRpcError, CancelledError) propagate to the caller here.
        call = await self._interceptors_task
        return await call.wait_for_connection()
468
469
class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response.

    Awaiting the object first waits for ``self._interceptors_task`` to
    produce the call built by the interceptor chain. It then awaits that
    call and returns its result.
    """

    def __await__(self):
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
475
476
class _InterceptedStreamResponseMixin:
    """Provides ``__aiter__``/``read`` for intercepted streaming responses.

    Both entry points share a single async generator. The generator waits
    for ``self._interceptors_task`` and then relays every message of the
    call it produced.
    """

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The generator is created lazily. Creating it here would make
        # asyncio log a warning whenever the stream is never consumed.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> ResponseType:
        intercepted_call = await self._interceptors_task
        async for message in intercepted_call:
            yield message

    def _shared_response_aiter(self) -> AsyncIterable[ResponseType]:
        # Build the relaying generator on first use and reuse it afterwards,
        # so iteration and read() consume the same stream.
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        return self._shared_response_aiter()

    async def read(self) -> Union[EOFType, ResponseType]:
        """Return the next response, or ``cygrpc.EOF`` once exhausted."""
        source = self._shared_response_aiter()
        try:
            return await source.asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
508
509
class _InterceptedStreamRequestMixin:
    """Implements ``write``/``done_writing`` for intercepted streaming calls.

    If the caller does not provide a request iterator, this mixin supplies
    one. It is an async generator that relays every value ``write()`` puts
    on a one-slot queue, until ``done_writing()`` enqueues a sentinel. If the
    caller does provide an iterator, ``write()`` and ``done_writing()``
    raise ``cygrpc.UsageError``.

    Relies on ``self._loop`` and ``self._interceptors_task`` being provided
    by the concrete call class.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    # Enqueued by done_writing() to end the proxy request iterator.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Pick the request iterator to hand to the interceptor chain.

        Args:
          request_iterator: Caller-provided iterator, or None when the caller
            will use ``write()``/``done_writing()``.

        Returns:
          ``request_iterator`` itself, or the internal proxy generator when
          it is None.
        """
        if request_iterator is None:
            # We provide our own request iterator, which proxies the future
            # writes made by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            self._status_code_task = None
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        # Nothing is yielded until the interceptor chain has produced its
        # call. Then queued writes are relayed until the sentinel arrives.
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self,
        request: RequestType,
        call: _base_call.Call,
    ):
        """Enqueue ``request``, giving up if ``call`` terminates first.

        The put on the one-slot queue races against a shared task awaiting
        ``call.code()``, which is how abrupt termination of the call is
        detected. If the call wins the race, or if this coroutine is itself
        cancelled, the pending put is cancelled. Otherwise it would stay
        blocked forever on a queue nobody drains, leaking a task.

        Args:
          request: Value (or the finish sentinel) to enqueue.
          call: The call produced by the interceptor chain.
        """
        if self._status_code_task is None:
            # Created once and shared by every later write/done_writing.
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        try:
            await asyncio.wait(
                (put_task, self._status_code_task),
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            if not put_task.done():
                put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Send ``request`` through the proxy request iterator.

        Raises:
          cygrpc.UsageError: The caller provided its own request iterator.
          asyncio.InvalidStateError: The RPC is already finished (including
            a failed or cancelled interceptor chain), or the client already
            half-closed.
        """
        # If no queue was created, requests are expected to come from an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        # The call may have terminated while we were waiting to enqueue.
        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created, requests are expected to come from an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        # NOTE(review): unlike write(), an AioRpcError raised by the
        # interceptor chain is not translated here and propagates as-is.
        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
608
609
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` through a chain of client interceptors.

    Awaiting this object waits for the interceptor chain to finish and then
    awaits the call it produced.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[UnaryUnaryCall, UnaryUnaryCallResponse]:
        """Run the RPC call wrapped in interceptors"""
        chain = list(interceptors)

        async def _run_interceptor(
            position: int,
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> Union[UnaryUnaryCall, UnaryUnaryCallResponse]:
            # Past the last interceptor: start the real RPC.
            if position >= len(chain):
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            # The continuation resumes the walk at the next interceptor. It
            # keeps the (client_call_details, request) parameter names that
            # interceptors may pass by keyword.
            next_step = functools.partial(_run_interceptor, position + 1)
            outcome = await chain[position].intercept_unary_unary(
                next_step, client_call_details, request
            )

            # Interceptors may short-circuit with a bare response value.
            if not isinstance(outcome, _base_call.UnaryUnaryCall):
                return UnaryUnaryCallResponse(outcome)
            return outcome

        details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(0, details, request)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
710
711
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors.

    Responses are read from whatever the interceptor chain produced. That is
    either a `UnaryStreamCall`, or a `UnaryStreamCallResponseIterator`
    wrapping an asynchronous iterator returned by an interceptor.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while walking the interceptor chain. It is
    # handed to UnaryStreamCallResponseIterator when an interceptor returns
    # an asynchronous iterator instead of a call. Declared as an annotation;
    # the original `= Optional[...]` assigned a typing object as a class
    # attribute by mistake.
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[UnaryStreamCall, UnaryStreamCallResponseIterator]:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> Union[UnaryStreamCall, UnaryStreamCallResponseIterator]:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a plain async iterator. Wrap
                    # it together with the innermost call produced so far.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: start the real RPC.
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
825
826
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Runs a `StreamUnaryCall` through a chain of client interceptors.

    Awaiting this object waits for the interceptor chain to finish and then
    awaits the call it produced. If no request iterator is given, requests
    are supplied through ``write()``/``done_writing()``.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        effective_requests = self._init_stream_request_mixin(request_iterator)
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            effective_requests,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""
        chain = list(interceptors)

        async def _run_interceptor(
            position: int,
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            # Past the last interceptor: start the real RPC.
            if position >= len(chain):
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            # The continuation resumes the walk at the next interceptor. It
            # keeps the (client_call_details, request_iterator) parameter
            # names that interceptors may pass by keyword.
            next_step = functools.partial(_run_interceptor, position + 1)
            return await chain[position].intercept_stream_unary(
                next_step, client_call_details, request_iterator
            )

        details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(0, details, request_iterator)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
926
927
class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors.

    Responses are read from whatever the interceptor chain produced. That is
    either a `StreamStreamCall`, or a `StreamStreamCallResponseIterator`
    wrapping an asynchronous iterator returned by an interceptor. If no
    request iterator is given, requests are supplied through
    ``write()``/``done_writing()``.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while walking the interceptor chain. It is
    # handed to StreamStreamCallResponseIterator when an interceptor returns
    # an asynchronous iterator instead of a call. Declared as an annotation;
    # the original `= Optional[...]` assigned a typing object as a class
    # attribute by mistake.
    _last_returned_call_from_interceptors: Optional[
        _base_call.StreamStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        request_iterator = self._init_stream_request_mixin(request_iterator)
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[StreamStreamCall, StreamStreamCallResponseIterator]:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> Union[StreamStreamCall, StreamStreamCallResponseIterator]:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a plain async iterator. Wrap
                    # it together with the innermost call produced so far.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: start the real RPC.
                self._last_returned_call_from_interceptors = StreamStreamCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )
                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
1046
1047
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """A UnaryUnaryCall that is already finished with a response value.

    The object never reports cancellation and always reports ``done()``.
    Its status accessors return fixed "successful" values: ``OK`` code,
    empty details, and no metadata or debug string. Awaiting the object
    produces the stored response. Callback registration and deadline queries
    are not supported.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    # Lifecycle: the call is complete and can no longer be cancelled.

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def add_done_callback(self, unused_callback) -> None:
        raise NotImplementedError

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError

    async def wait_for_connection(self) -> None:
        pass

    # Status: constant values describing a successful call.

    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ""

    async def debug_error_string(self) -> Optional[str]:
        return None

    def __await__(self):
        # ``yield from ()`` yields nothing but makes this function a
        # generator, which the await protocol requires of ``__await__``.
        # The return value becomes the result of ``await``.
        yield from ()
        return self._response
1095
1096
class _StreamCallResponseIterator:
    """Pair a streaming call with an alternative source of responses.

    Iterating with ``async for`` pulls items from ``response_iterator``.
    Every other query (cancellation, completion, callbacks, deadline,
    metadata, status, connection readiness) is forwarded unchanged to
    ``call``.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    # Response stream: served by the substitute iterator.

    def __aiter__(self):
        return self._response_iterator.__aiter__()

    # Synchronous call-state queries, forwarded to the wrapped call.

    def cancel(self) -> bool:
        return self._call.cancel()

    def cancelled(self) -> bool:
        return self._call.cancelled()

    def done(self) -> bool:
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        # Deliberately returns None regardless of what the wrapped call does.
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    # Asynchronous queries, awaited on the wrapped call.

    async def initial_metadata(self) -> Optional[Metadata]:
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        return await self._call.code()

    async def details(self) -> str:
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        return await self._call.debug_error_string()

    async def wait_for_connection(self) -> None:
        return await self._call.wait_for_connection()
1144
1145
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall whose responses come from a substitute async iterator.

    Responses are obtained only through ``async for``; the explicit
    ``read()`` API is not supported.
    """

    async def read(self) -> Union[EOFType, ResponseType]:
        # All responses flow through the async iterator, so this explicit
        # read path is not expected to be reached.
        raise NotImplementedError
1155
1156
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall whose responses come from a substitute async iterator.

    Responses are consumed only through ``async for``. Requests are expected
    to come from the async iterator supplied by the InterceptedStreamStreamCall.
    For that reason the explicit ``read``/``write``/``done_writing`` methods
    are unsupported here.
    """

    async def read(self) -> Union[EOFType, ResponseType]:
        # Responses are pulled through the async iterator; explicit reads
        # are not expected.
        raise NotImplementedError

    async def write(self, request: RequestType) -> None:
        # Requests are fed through the InterceptedStreamStreamCall's async
        # iterator; explicit writes are not expected.
        raise NotImplementedError

    async def done_writing(self) -> None:
        # Half-closing likewise happens via that async iterator, not here.
        raise NotImplementedError

    @property
    def _done_writing_flag(self) -> bool:
        """Mirror the wrapped call's ``_done_writing_flag``."""
        return self._call._done_writing_flag