1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementation of gRPC Python interceptors."""
15
16import collections
17import sys
18import types
19from typing import Any, Callable, Optional, Sequence, Tuple, Union
20
21import grpc
22
23from ._typing import DeserializingFunction
24from ._typing import DoneCallbackType
25from ._typing import MetadataType
26from ._typing import RequestIterableType
27from ._typing import SerializingFunction
28
29
class _ServicePipeline(object):
    """Chains server interceptors in front of a terminal handler lookup.

    Each interceptor's ``intercept_service`` receives a continuation that,
    when called with a ``grpc.HandlerCallDetails``, runs the next interceptor;
    past the last interceptor the original ``thunk`` is called.
    """

    # Interceptors in application order; any number of them, so the tuple
    # type is variable-length.
    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a callable that resumes the chain at position ``index``."""
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(
        self, thunk: Callable, index: int, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the interceptor at ``index``, or ``thunk`` once past the end.

        Args:
          thunk: Terminal callable invoked with ``context`` after every
            interceptor has delegated to its continuation.
          index: Position in ``self.interceptors`` to run next.
          context: Details of the incoming call.

        Returns:
          Whatever the interceptor at ``index`` (or ``thunk``) returns.
        """
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(
        self, thunk: Callable, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the whole chain for ``context``, ending in ``thunk``."""
        return self._intercept_at(thunk, 0, context)
53
54
def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]],
) -> Optional[_ServicePipeline]:
    """Build a ``_ServicePipeline`` only when there is something to chain.

    Args:
      interceptors: Server interceptors to apply in order. ``None`` and an
        empty sequence are treated the same way.

    Returns:
      ``None`` when ``interceptors`` is falsy, otherwise a
      ``_ServicePipeline`` over them.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
59
60
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        "method timeout metadata credentials wait_for_ready compression",
    ),
    grpc.ClientCallDetails,
):
    """Immutable ``grpc.ClientCallDetails`` implemented as a namedtuple.

    Positional field order: ``method``, ``timeout``, ``metadata``,
    ``credentials``, ``wait_for_ready``, ``compression``.
    """
76
77
def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails,
) -> Tuple[
    str,
    Optional[float],
    Optional[MetadataType],
    Optional[grpc.CallCredentials],
    Optional[bool],
    Optional[grpc.Compression],
]:
    """Read per-call parameters, falling back to defaults for missing ones.

    An interceptor may pass its continuation any object as ``call_details``;
    every attribute that object lacks is taken from ``default_details``.

    Args:
      call_details: Details supplied by an interceptor (possibly partial).
      default_details: Details the call was originally issued with.

    Returns:
      A ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)`` tuple.

    Raises:
      AttributeError: If an attribute is missing from both objects.
    """

    def _lookup(name: str) -> Any:
        # default_details is consulted only when call_details lacks the
        # attribute (not eagerly, as getattr's default argument would be).
        try:
            return getattr(call_details, name)
        except AttributeError:
            return getattr(default_details, name)

    # Tuple elements evaluate left to right, preserving the lookup order.
    return (
        _lookup("method"),
        _lookup("timeout"),
        _lookup("metadata"),
        _lookup("credentials"),
        _lookup("wait_for_ready"),
        _lookup("compression"),
    )
127
128
class _FailureOutcome(
    grpc.RpcError, grpc.Future, grpc.Call
):  # pylint: disable=too-many-ancestors
    """Already-finished RPC outcome carrying an exception from interception.

    It is itself a ``grpc.RpcError`` and also satisfies ``grpc.Future`` and
    ``grpc.Call``. It reports done, not running, not cancelled, inactive,
    with status ``INTERNAL``. ``result()`` and iteration re-raise the
    captured exception.
    """

    # Exception that escaped during interception, and its traceback.
    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # --- grpc.Call / RpcContext surface: fixed "failed and finished" view.

    def initial_metadata(self) -> Optional[MetadataType]:
        """Always None; no metadata was received."""
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Always None; no metadata was received."""
        return None

    def code(self) -> Optional[grpc.StatusCode]:
        """Always ``grpc.StatusCode.INTERNAL``."""
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        """A fixed message describing the interception failure."""
        return "Exception raised while intercepting the RPC"

    def is_active(self) -> bool:
        return False

    def time_remaining(self) -> Optional[float]:
        return None

    def add_callback(self, unused_callback) -> bool:
        """Callbacks are never registered; always returns False."""
        return False

    # --- grpc.Future surface: completed with an exception.

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Re-raise the captured exception."""
        raise self._exception

    def exception(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[Exception]:
        return self._exception

    def traceback(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        return self._traceback

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` immediately, since the outcome is already done."""
        fn(self)

    # --- Iterator surface: the first step re-raises the captured exception.

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        return self.__next__()
197
198
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Finished unary-response outcome pairing a response with its call.

    Call-level queries (metadata, status, activity, time remaining,
    cancellation, callbacks) are forwarded to the wrapped ``grpc.Call``.
    The future-level view is fixed: done, not running, not cancelled, no
    exception, and ``result()`` returns the stored response.
    """

    _response: Any  # response message returned by the RPC
    _call: grpc.Call  # call object the response came from

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    # --- Forwarded to the underlying call.

    def initial_metadata(self) -> Optional[MetadataType]:
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        return self._call.code()

    def details(self) -> Optional[str]:
        return self._call.details()

    def is_active(self) -> bool:
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        return self._call.add_callback(callback)

    # --- Fixed "completed successfully" future state.

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Return the stored response without waiting."""
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` immediately, since the outcome is already done."""
        fn(self)
251
252
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying unary-unary
    multi-callable. It is called with the method name carried by the
    interceptor's (possibly rewritten) call details.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> grpc.ClientCallDetails:
        """Bundle the caller's options with this callable's method name."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _target(new_details, default_details) -> Tuple[str, dict]:
        """Resolve interceptor details into a method name and call kwargs."""
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, default_details)
        return method, {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return only the response."""
        response, _ = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor and return ``(outcome.result(), outcome)``.

        Inside the continuation, a ``grpc.RpcError`` from the underlying
        call is returned as the outcome. Any other exception is wrapped in a
        ``_FailureOutcome``. Either way ``result()`` then raises for the
        caller. Exceptions raised by the interceptor itself propagate.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._target(new_details, default_details)
            try:
                response, call = self._thunk(method).with_call(
                    request, **options
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, default_details, request
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC and return ``(response, call)``."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Start the RPC asynchronously through the interceptor.

        Returns:
          Whatever ``intercept_unary_unary`` returns, or a ``_FailureOutcome``
          wrapping any exception it raises.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._target(new_details, default_details)
            return self._thunk(method).future(request, **options)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, default_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
394
395
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying unary-stream
    multi-callable. It is called with the method name carried by the
    interceptor's (possibly rewritten) call details.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> grpc.ClientCallDetails:
        """Bundle the caller's options with this callable's method name."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _target(new_details, default_details) -> Tuple[str, dict]:
        """Resolve interceptor details into a method name and call kwargs."""
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, default_details)
        return method, {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the response-streaming RPC through the interceptor.

        Returns:
          Whatever ``intercept_unary_stream`` returns, or a
          ``_FailureOutcome`` wrapping any exception it raises.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._target(new_details, default_details)
            return self._thunk(method)(request, **options)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, default_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
453
454
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying stream-unary
    multi-callable. It is called with the method name carried by the
    interceptor's (possibly rewritten) call details.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> grpc.ClientCallDetails:
        """Bundle the caller's options with this callable's method name."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _target(new_details, default_details) -> Tuple[str, dict]:
        """Resolve interceptor details into a method name and call kwargs."""
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, default_details)
        return method, {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return only the response."""
        response, _ = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor and return ``(outcome.result(), outcome)``.

        Inside the continuation, a ``grpc.RpcError`` from the underlying
        call is returned as the outcome. Any other exception is wrapped in a
        ``_FailureOutcome``. Either way ``result()`` then raises for the
        caller. Exceptions raised by the interceptor itself propagate.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._target(new_details, default_details)
            try:
                response, call = self._thunk(method).with_call(
                    request_iterator, **options
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, default_details, request_iterator
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC and return ``(response, call)``."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Start the RPC asynchronously through the interceptor.

        Returns:
          Whatever ``intercept_stream_unary`` returns, or a
          ``_FailureOutcome`` wrapping any exception it raises.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._target(new_details, default_details)
            return self._thunk(method).future(request_iterator, **options)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, default_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
596
597
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying stream-stream
    multi-callable. It is called with the method name carried by the
    interceptor's (possibly rewritten) call details.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> grpc.ClientCallDetails:
        """Bundle the caller's options with this callable's method name."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _target(new_details, default_details) -> Tuple[str, dict]:
        """Resolve interceptor details into a method name and call kwargs."""
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, default_details)
        return method, {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the bidirectional-streaming RPC through the interceptor.

        Returns:
          Whatever ``intercept_stream_stream`` returns, or a
          ``_FailureOutcome`` wrapping any exception it raises.
        """
        default_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._target(new_details, default_details)
            return self._thunk(method)(request_iterator, **options)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, default_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
655
656
class _Channel(grpc.Channel):
    """A ``grpc.Channel`` wrapper that applies one client interceptor.

    For each RPC arity, the returned multi-callable is wrapped only if the
    interceptor implements the matching ``grpc.*ClientInterceptor``
    interface. Otherwise the wrapped channel's multi-callable is returned
    as-is. Connectivity subscription and closing are forwarded to the
    wrapped channel.
    """

    _channel: grpc.Channel  # the channel being wrapped
    _interceptor: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ]

    def __init__(
        self,
        channel: grpc.Channel,
        interceptor: Union[
            grpc.UnaryUnaryClientInterceptor,
            grpc.UnaryStreamClientInterceptor,
            grpc.StreamStreamClientInterceptor,
            grpc.StreamUnaryClientInterceptor,
        ],
    ):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(
        self, callback: Callable, try_to_connect: Optional[bool] = False
    ):
        """Forward a connectivity subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        """Forward an unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        # pytype: disable=wrong-arg-count
        def thunk(name):
            return self._channel.unary_unary(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )

        # pytype: enable=wrong-arg-count
        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        # pytype: disable=wrong-arg-count
        def thunk(name):
            return self._channel.unary_stream(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )

        # pytype: enable=wrong-arg-count
        if not isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        # pytype: disable=wrong-arg-count
        def thunk(name):
            return self._channel.stream_unary(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )

        # pytype: enable=wrong-arg-count
        if not isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        # pytype: disable=wrong-arg-count
        def thunk(name):
            return self._channel.stream_stream(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )

        # pytype: enable=wrong-arg-count
        if not isinstance(
            self._interceptor, grpc.StreamStreamClientInterceptor
        ):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the wrapped channel; never suppress the in-flight exception.
        self._close()
        return False

    def close(self):
        """Close the wrapped channel."""
        self._close()
783
784
def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ],
) -> grpc.Channel:
    """Wrap ``channel`` so that RPCs pass through ``interceptors``.

    Interceptors are applied from last to first, so the first one given ends
    up as the outermost wrapper and runs first on each call.

    Args:
      channel: The channel to wrap.
      *interceptors: Client interceptors. Each must implement at least one
        of the four ``grpc.*ClientInterceptor`` interfaces.

    Returns:
      The wrapped channel, or ``channel`` itself when no interceptors are
      given.

    Raises:
      TypeError: If an interceptor implements none of the client
        interceptor interfaces.
    """
    accepted_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, accepted_types):
            raise TypeError(
                "interceptor must be "
                "grpc.UnaryUnaryClientInterceptor or "
                "grpc.UnaryStreamClientInterceptor or "
                "grpc.StreamUnaryClientInterceptor or "
                "grpc.StreamStreamClientInterceptor"
            )
        channel = _Channel(channel, interceptor)
    return channel