1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from __future__ import annotations
16
17from abc import ABCMeta
18from abc import abstractmethod
19import asyncio
20import collections
21import functools
22from typing import (
23 AsyncIterable,
24 AsyncIterator,
25 Awaitable,
26 Callable,
27 List,
28 Optional,
29 Sequence,
30 Union,
31)
32
33import grpc
34from grpc._cython import cygrpc
35
36from . import _base_call
37from ._call import AioRpcError
38from ._call import StreamStreamCall
39from ._call import StreamUnaryCall
40from ._call import UnaryStreamCall
41from ._call import UnaryUnaryCall
42from ._call import _API_STYLE_ERROR
43from ._call import _RPC_ALREADY_FINISHED_DETAILS
44from ._call import _RPC_HALF_CLOSED_DETAILS
45from ._metadata import Metadata
46from ._typing import DeserializingFunction
47from ._typing import DoneCallbackType
48from ._typing import EOFType
49from ._typing import RequestIterableType
50from ._typing import RequestType
51from ._typing import ResponseIterableType
52from ._typing import ResponseType
53from ._typing import SerializingFunction
54from ._utils import _timeout_to_deadline
55
# Details string returned by InterceptedCall.details() when the interceptors
# task ended with asyncio.CancelledError (no call was ever produced).
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
57
58
class ServerInterceptor(metaclass=ABCMeta):
    """Affords intercepting incoming RPCs on the service-side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails],
            Awaitable[Optional[grpc.RpcMethodHandler]],
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        """Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        via contextvars. The first interceptor is called from an empty
        contextvars.Context, and the same Context is used for downstream
        interceptors and for the final handler call. Note that there are no
        guarantees that interceptors and handlers will be called from the
        same thread.

        Args:
            continuation: A function that takes a HandlerCallDetails and
                proceeds to invoke the next interceptor in the chain, if any,
                or the RPC handler lookup logic, with the call details passed
                as an argument. Awaiting its result yields an RpcMethodHandler
                instance if the RPC is considered serviced, or None otherwise.
            handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
            An RpcMethodHandler with which the RPC may be serviced if the
            interceptor chooses to service this RPC, or None otherwise.
        """
94
95
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        ["method", "timeout", "metadata", "credentials", "wait_for_ready"],
    ),
    grpc.ClientCallDetails,
):
    """Immutable description of an outgoing RPC, as seen by interceptors.

    This is an EXPERIMENTAL API.

    Fields (in positional order):
        method: Name of the RPC method to invoke.
        timeout: Optional time budget for the RPC, in seconds.
        metadata: Optional metadata sent to the service side with the RPC.
        credentials: Optional CallCredentials attached to the RPC.
        wait_for_ready: Optional flag turning on the
            :term:`wait_for_ready` mechanism.
    """

    method: bytes
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
121
122
class ClientInterceptor(metaclass=ABCMeta):
    """Common root of every asyncio client interceptor interface.

    Concrete interceptors derive from one or more of the arity-specific
    subclasses (unary-unary, unary-stream, stream-unary, stream-stream)
    rather than from this class directly.
    """
125
126
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], Awaitable[UnaryUnaryCall]
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercepts a unary-unary invocation asynchronously.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. Awaiting `continuation` returns the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            The RPC call, or directly the response message. A bare response
            message is wrapped by the caller so it still exposes a call
            object.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
161
162
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], Awaitable[UnaryStreamCall]
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercepts a unary-stream invocation asynchronously.

        The function could return the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. Awaiting `continuation` returns the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
201
202
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType],
            Awaitable[StreamUnaryCall],
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, calling methods such as `write` on the call,
        or even awaiting the call, should be done carefully. The caller may
        expect an untouched call, for example to start writing messages to
        it.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. Awaiting `continuation` returns the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
243
244
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType],
            Awaitable[StreamStreamCall],
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, calling methods such as `write` on the call,
        or even awaiting the call, should be done carefully. The caller may
        expect an untouched call, for example to start writing messages to
        it.

        The function could return the call object or an asynchronous
        iterator. If it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine function that proceeds with the
                invocation by executing the next interceptor in the chain or
                invoking the actual RPC on the underlying Channel. It is the
                interceptor's responsibility to call it if it decides to move
                the RPC forward. The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. Awaiting `continuation` returns the
                call to the RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """
289
290
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the capacity of changing the invocation parameters, and some work to do
    after the RPC invocation, with the capacity of accessing the wrapped call.

    It also handles early and late cancellations. An early cancellation
    happens when the RPC has not started yet and execution is still held by
    the interceptors. A late cancellation happens when the RPC has finished
    but execution is still held by the interceptors.

    Once the interceptors have produced the call, every method here is
    forwarded to that call, which is also the call the interceptors saw.

    As the base class of every intercepted call, it implements the logic
    around final status, metadata, cancellation and done callbacks.

    If the interceptors task finishes without producing a call, the
    intercepted call is treated as finished. This covers an AioRpcError, a
    cancellation, or any other exception raised by interceptor code. In that
    case `done()` returns True, `cancel()` returns False and done callbacks
    fire immediately.
    """

    _interceptors_task: asyncio.Task
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # A subclass __init__ can fail (e.g. inside loop.create_task) before
        # the task attribute is assigned; there is nothing to cancel then.
        if getattr(self, "_interceptors_task", None) is None:
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Done callback of the interceptors task.

        Fires the callbacks queued by `add_done_callback` if the RPC is
        already over. Otherwise it re-registers them on the intercepted call.
        """
        if not self._pending_add_done_callbacks:
            return

        try:
            call = interceptors_task.result()
        except (asyncio.CancelledError, Exception):
            # No call was produced (RPC error, cancellation or an exception
            # raised by an interceptor), so the RPC is over. Fire now instead
            # of letting the exception escape into the event loop and leaving
            # the callbacks pending forever.
            call_completed = True
        else:
            call_completed = call.done()

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                call.add_done_callback(
                    functools.partial(self._wrap_add_done_callback, callback)
                )

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        # Callbacks must receive the intercepted call, not the inner call.
        callback(self)

    def cancel(self) -> bool:
        """Cancel the RPC; return True if a cancellation was requested."""
        if not self._interceptors_task.done():
            # The intercepted call is not available yet, so cancel using
            # the generic asyncio cancellation of the interceptors task.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except (asyncio.CancelledError, Exception):
            # The interceptors finished without a call; nothing to cancel.
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Return True if the RPC ended because it was cancelled."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except asyncio.CancelledError:
            return True
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except Exception:  # pylint: disable=broad-except
            # An interceptor failed with a non-gRPC error: not a cancellation.
            return False

        return call.cancelled()

    def done(self) -> bool:
        """Return True once the RPC (or the interceptor chain) is over."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (asyncio.CancelledError, Exception):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Run `callback(self)` once the RPC is over (immediately if it is)."""
        if not self._interceptors_task.done():
            # Re-examined by _fire_or_add_pending_done_callbacks later.
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (asyncio.CancelledError, Exception):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        call = await self._interceptors_task
        return await call.wait_for_connection()
469
470
class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response.

    Awaiting first waits for the interceptors task to hand over the real
    call, then delegates to that call's own `__await__`.
    """

    def __await__(self):
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
476
477
class _InterceptedStreamResponseMixin:
    """Response-streaming half of intercepted unary-stream/stream-stream calls.

    Both `async for` and `read()` share one async generator. That generator
    waits for the interceptors task and then relays the responses of the
    call it produced.
    """

    _response_aiter: Optional[AsyncIterator[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # Created lazily: otherwise, if the iterator is never consumed,
        # asyncio emits a logging warning.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> AsyncIterator[ResponseType]:
        """Wait for the intercepted call, then yield each of its responses."""
        call = await self._interceptors_task
        async for response in call:
            yield response

    def _get_response_aiter(self) -> AsyncIterator[ResponseType]:
        """Return the shared response generator, creating it on first use."""
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        return self._get_response_aiter()

    async def read(self) -> Union[EOFType, ResponseType]:
        """Return the next response, or `cygrpc.EOF` once the stream ends."""
        response_aiter = self._get_response_aiter()
        try:
            return await response_aiter.asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
509
510
class _InterceptedStreamRequestMixin:
    """Request-streaming half of intercepted stream-unary/stream-stream calls.

    When the caller supplies no request iterator, requests passed to
    `write()` go through a size-1 queue. An async generator drains that
    queue and is used as the request iterator of the RPC.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Set up request streaming and return the iterator for the RPC.

        Args:
            request_iterator: The caller-provided request iterator, or None
                to enable the `write()` / `done_writing()` API.

        Returns:
            `request_iterator` itself when given, otherwise a proxy async
            generator fed by `write()`.
        """
        # Only the write()-based API ever creates this task; initialize it on
        # both branches so the attribute always exists.
        self._status_code_task = None
        if request_iterator is None:
            # We provide our own request iterator, which is a proxy of the
            # future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None
            self._write_to_iterator_async_gen = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yield queued writes until the finish sentinel is dequeued."""
        # Start consuming only once the interceptors have produced the call.
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self,
        request: RequestType,
        call: _base_call.Call,
    ):
        """Enqueue `request`, giving up if `call` terminates first.

        Write the specified 'request' to the request iterator queue using the
        specified 'call' to allow for interruption of the write in the case of
        abrupt termination of the call.
        """
        if self._status_code_task is None:
            # Shared by every write: resolves once the RPC has a final status.
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        await asyncio.wait(
            (put_task, self._status_code_task),
            return_when=asyncio.FIRST_COMPLETED,
        )
        if not put_task.done():
            # The call ended while the queue was still full. Nobody drains it
            # any more, so drop the pending put. Otherwise a leaked task could
            # later push a stray request after the call has finished.
            put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Send `request` to the peer through the proxy request iterator.

        Raises:
            UsageError: If the call was created with a request iterator.
            asyncio.InvalidStateError: If the RPC already finished or the
                client already half-closed it.
        """
        # If no queue was created it means that requests
        # should be expected through an iterators provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created it means that requests
        # should be expected through an iterators provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
609
610
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` through a chain of client interceptors.

    Awaiting this object is forwarded to the intercepted call only after the
    interceptor chain has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[UnaryUnaryCall, UnaryUnaryCallResponse]:
        """Pass the request through every interceptor, then issue the RPC."""

        async def _run_interceptor(
            remaining: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> Union[UnaryUnaryCall, UnaryUnaryCallResponse]:
            if not remaining:
                # End of the chain: perform the real RPC on the channel.
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            next_step = functools.partial(_run_interceptor, remaining[1:])
            outcome = await remaining[0].intercept_unary_unary(
                next_step, client_call_details, request
            )
            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            # The interceptor answered with a bare response message; wrap it
            # so the caller still receives a call object.
            return UnaryUnaryCallResponse(outcome)

        initial_details = ClientCallDetails(
            method=method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
        )
        return await _run_interceptor(
            list(interceptors), initial_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
709
710
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors.

    Responses are read through the stream-response mixin once the
    interceptor chain has produced a call or a response iterator.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # The call most recently returned while unwinding the interceptor chain
    # (innermost first). This is a type annotation: the previous `=` made it
    # a class attribute holding a typing object.
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[UnaryStreamCall, UnaryStreamCallResponseIterator]:
        """Run the RPC call wrapped in interceptors."""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> Union[UnaryStreamCall, UnaryStreamCallResponseIterator]:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare async iterator. Pair it
                    # with the call returned by the inner chain, presumably so
                    # the wrapper can delegate call-level methods to it.
                    # NOTE(review): this is None if the interceptor never
                    # invoked `continuation`.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors

            # End of the chain: perform the real RPC on the channel.
            self._last_returned_call_from_interceptors = UnaryStreamCall(
                request,
                _timeout_to_deadline(client_call_details.timeout),
                client_call_details.metadata,
                client_call_details.credentials,
                client_call_details.wait_for_ready,
                self._channel,
                client_call_details.method,
                request_serializer,
                response_deserializer,
                self._loop,
            )

            return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
823
824
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Runs a `StreamUnaryCall` through a chain of client interceptors.

    Awaiting this object is forwarded to the intercepted call only after the
    interceptor chain has finished. Requests come from the caller-supplied
    iterator or, when none was given, from `write()` / `done_writing()`.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        effective_requests = self._init_stream_request_mixin(request_iterator)
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            effective_requests,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> StreamUnaryCall:
        """Pass the request stream through every interceptor, then issue the RPC."""

        async def _run_interceptor(
            remaining: Sequence[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            if not remaining:
                # End of the chain: perform the real RPC on the channel.
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            head, tail = remaining[0], remaining[1:]
            next_step = functools.partial(_run_interceptor, tail)
            return await head.intercept_stream_unary(
                next_step, client_call_details, request_iterator
            )

        initial_details = ClientCallDetails(
            method=method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
        )
        return await _run_interceptor(
            list(interceptors), initial_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
923
924
class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors.

    Requests come from the caller-supplied iterator or, when none was given,
    from `write()` / `done_writing()`. Responses are read through the
    stream-response mixin once the interceptor chain has produced a call or
    a response iterator.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # The call most recently returned while unwinding the interceptor chain
    # (innermost first). This is a type annotation: the previous `=` made it
    # a class attribute holding a typing object.
    _last_returned_call_from_interceptors: Optional[
        _base_call.StreamStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        request_iterator = self._init_stream_request_mixin(request_iterator)
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
    ) -> Union[StreamStreamCall, StreamStreamCallResponseIterator]:
        """Run the RPC call wrapped in interceptors."""

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> Union[StreamStreamCall, StreamStreamCallResponseIterator]:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare async iterator. Pair it
                    # with the call returned by the inner chain, presumably so
                    # the wrapper can delegate call-level methods to it.
                    # NOTE(review): this is None if the interceptor never
                    # invoked `continuation`.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors

            # End of the chain: perform the real RPC on the channel.
            self._last_returned_call_from_interceptors = StreamStreamCall(
                request_iterator,
                _timeout_to_deadline(client_call_details.timeout),
                client_call_details.metadata,
                client_call_details.credentials,
                client_call_details.wait_for_ready,
                self._channel,
                client_call_details.method,
                request_serializer,
                response_deserializer,
                self._loop,
            )
            return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
1042
1043
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """Final UnaryUnaryCall class finished with a response.

    Wraps a plain response value so it can be used where a ``UnaryUnaryCall``
    is expected. The represented call has already finished:

    - it is done and was not cancelled;
    - it reports ``grpc.StatusCode.OK`` with empty details and no metadata;
    - awaiting it yields the wrapped response without suspending.
    """

    # The value produced by ``await``-ing this object.
    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    def cancel(self) -> bool:
        """Return False; an already-finished call cannot be cancelled."""
        return False

    def cancelled(self) -> bool:
        """Return False; this call is never reported as cancelled."""
        return False

    def done(self) -> bool:
        """Return True; this call is always finished."""
        return True

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Invoke ``callback`` right away with this call as its only argument.

        ``done()`` is always True, so there is no later termination to wait
        for. The callback is run synchronously instead of being rejected, so
        code that registers completion callbacks works with this object too.

        Args:
          callback: Callable invoked with this call object.
        """
        callback(self)

    def time_remaining(self) -> Optional[float]:
        """Not supported for a pre-completed response.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Return None; no initial metadata is available."""
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Return None; no trailing metadata is available."""
        return None

    async def code(self) -> grpc.StatusCode:
        """Return ``grpc.StatusCode.OK``."""
        return grpc.StatusCode.OK

    async def details(self) -> str:
        """Return an empty details string."""
        return ""

    async def debug_error_string(self) -> Optional[str]:
        """Return None; there is no error to describe."""
        return None

    def __await__(self):
        # The ``yield`` below never executes. Its presence makes this method a
        # generator (an iterator, as ``__await__`` requires), so awaiting the
        # object completes immediately with the stored response.
        if False:  # pylint: disable=using-constant-test
            yield None
        return self._response

    async def wait_for_connection(self) -> None:
        """No-op; the call has already finished."""
1091
1092
class _StreamCallResponseIterator:
    """Pairs an existing call with a replacement response stream.

    Call-level queries and actions forward to the wrapped call. These are
    cancellation, status, metadata, done callbacks, remaining time and
    waiting for the connection. ``async for`` iterates the supplied response
    iterator instead.
    """

    # Call object that every non-iteration method delegates to.
    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    # Source of the items produced by ``__aiter__``.
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    def cancel(self) -> bool:
        """Forward cancellation to the wrapped call and return its result."""
        wrapped = self._call
        return wrapped.cancel()

    def cancelled(self) -> bool:
        """Report whether the wrapped call was cancelled."""
        wrapped = self._call
        return wrapped.cancelled()

    def done(self) -> bool:
        """Report whether the wrapped call has finished."""
        wrapped = self._call
        return wrapped.done()

    def add_done_callback(self, callback) -> None:
        """Register ``callback`` on the wrapped call; always returns None."""
        wrapped = self._call
        wrapped.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        """Return the wrapped call's remaining time."""
        wrapped = self._call
        return wrapped.time_remaining()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's initial metadata."""
        metadata = await self._call.initial_metadata()
        return metadata

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Await and return the wrapped call's trailing metadata."""
        metadata = await self._call.trailing_metadata()
        return metadata

    async def code(self) -> grpc.StatusCode:
        """Await and return the wrapped call's status code."""
        status = await self._call.code()
        return status

    async def details(self) -> str:
        """Await and return the wrapped call's status details."""
        text = await self._call.details()
        return text

    async def debug_error_string(self) -> Optional[str]:
        """Await and return the wrapped call's debug error string."""
        debug_text = await self._call.debug_error_string()
        return debug_text

    def __aiter__(self):
        source = self._response_iterator
        return source.__aiter__()

    async def wait_for_connection(self) -> None:
        """Await the wrapped call's ``wait_for_connection`` and pass its result through."""
        outcome = await self._call.wait_for_connection()
        return outcome
1140
1141
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator.

    Call-level operations are forwarded to the wrapped call by
    ``_StreamCallResponseIterator``. Responses are expected to be consumed
    with ``async for``.
    """

    async def read(self) -> Union[EOFType, ResponseType]:
        """Unsupported; consume responses with ``async for`` instead.

        Raises:
          NotImplementedError: always.
        """
        # Everything goes through the async iterator behind the scenes, so
        # this path should never be reached.
        raise NotImplementedError
1151
1152
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator.

    Call-level operations are forwarded to the wrapped call by
    ``_StreamCallResponseIterator``. Reading and writing through this object
    is unsupported:

    - responses come from ``async for``;
    - requests are fed through the async iterator provided by the
      InterceptedStreamStreamCall.
    """

    async def read(self) -> Union[EOFType, ResponseType]:
        """Unsupported; consume responses with ``async for`` instead.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    async def write(self, request: RequestType) -> None:
        """Unsupported; requests flow through the intercepted call's iterator.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    async def done_writing(self) -> None:
        """Unsupported; requests flow through the intercepted call's iterator.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError

    @property
    def _done_writing_flag(self) -> bool:
        """The wrapped call's ``_done_writing_flag``, read on each access."""
        wrapped_call = self._call
        return wrapped_call._done_writing_flag