Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/grpc/_server.py: 37%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Service-side implementation of gRPC Python."""
16from __future__ import annotations
18import abc
19import collections
20from concurrent import futures
21import contextvars
22import enum
23import logging
24import threading
25import time
26import traceback
27from typing import (
28 Any,
29 Callable,
30 Dict,
31 Iterable,
32 Iterator,
33 List,
34 Mapping,
35 Optional,
36 Sequence,
37 Set,
38 Tuple,
39 Union,
40)
42import grpc # pytype: disable=pyi-error
43from grpc import _common # pytype: disable=pyi-error
44from grpc import _compression # pytype: disable=pyi-error
45from grpc import _interceptor # pytype: disable=pyi-error
46from grpc import _observability # pytype: disable=pyi-error
47from grpc._cython import cygrpc
48from grpc._typing import ArityAgnosticMethodHandler
49from grpc._typing import ChannelArgumentType
50from grpc._typing import DeserializingFunction
51from grpc._typing import MetadataType
52from grpc._typing import NullaryCallbackType
53from grpc._typing import ResponseType
54from grpc._typing import SerializingFunction
55from grpc._typing import ServerCallbackTag
56from grpc._typing import ServerTagCallbackType
57from typing_extensions import override
_LOGGER = logging.getLogger(__name__)

# Server-level tag strings.
_SHUTDOWN_TAG = "shutdown"
_REQUEST_CALL_TAG = "request_call"

# Batch tokens. A token is added to ``_RPCState.due`` when the matching batch
# is started and removed by ``_possibly_finish_call`` when that batch
# completes. The combined tokens name batches carrying two operations.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server"
_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata"
_RECEIVE_MESSAGE_TOKEN = "receive_message"
_SEND_MESSAGE_TOKEN = "send_message"
_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    f"{_SEND_INITIAL_METADATA_TOKEN} * {_SEND_MESSAGE_TOKEN}"
)
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    f"{_SEND_INITIAL_METADATA_TOKEN} * {_SEND_STATUS_FROM_SERVER_TOKEN}"
)

# Values taken by ``_RPCState.client``; compared by identity (``is``).
_OPEN = "open"
_CLOSED = "closed"
_CANCELLED = "cancelled"

# Empty flags value passed to cygrpc operations.
_EMPTY_FLAGS = 0

# Timing constants used by server lifecycle code outside this section.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
    """Return the raw message of the first operation in the event's batch.

    ``_receive_message`` treats a ``None`` result as "no message received".
    """
    operations = request_event.batch_operations
    return operations[0].message()
def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate an application ``grpc.StatusCode`` to its cygrpc equivalent.

    Codes absent from ``_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE`` map to
    ``cygrpc.StatusCode.unknown``.
    """
    translated = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if translated is None:
        return cygrpc.StatusCode.unknown
    return translated
def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Return the status code for a normally completing RPC.

    An application-set ``state.code`` wins; otherwise the RPC completes OK.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
def _abortion_code(
    state: _RPCState, code: cygrpc.StatusCode
) -> cygrpc.StatusCode:
    """Return the status code for an aborting RPC.

    An application-set ``state.code`` takes priority over the framework's
    suggested *code*.
    """
    application_code = state.code
    return (
        code if application_code is None else _application_code(application_code)
    )
def _details(state: _RPCState) -> bytes:
    """Return the status details recorded on *state*, or ``b""`` if unset."""
    if state.details is not None:
        return state.details
    return b""
class _HandlerCallDetails(
    collections.namedtuple(
        "_HandlerCallDetails", ("method", "invocation_metadata")
    ),
    grpc.HandlerCallDetails,
):
    """Immutable (method, invocation_metadata) pair describing an incoming call.

    Built by ``_find_method_handler`` and passed to handler lookup and to the
    interceptor pipeline.
    """
class _Method(abc.ABC):
    """Strategy for resolving the method handler of an incoming call."""

    @abc.abstractmethod
    def name(self) -> Optional[str]:
        """Return the method name known in advance, or ``None``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def handler(
        self, handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the handler for the call, or ``None`` if nothing matches."""
        raise NotImplementedError()
class _RegisteredMethod(_Method):
    """A method whose name and handler were registered up front."""

    def __init__(
        self,
        name: str,
        registered_handler: Optional[grpc.RpcMethodHandler],
    ):
        self._name, self._registered_handler = name, registered_handler

    @override
    def name(self) -> Optional[str]:
        return self._name

    @override
    def handler(
        self, handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        # The handler is fixed at registration; call details are not consulted.
        return self._registered_handler
class _GenericMethod(_Method):
    """A method resolved at call time by querying generic handlers in order."""

    def __init__(
        self,
        generic_handlers: List[grpc.GenericRpcHandler],
    ):
        self._generic_handlers = generic_handlers

    @override
    def name(self) -> Optional[str]:
        # Unknown up front; the caller falls back to the call's method name.
        return None

    @override
    def handler(
        self, handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        # If the same method has both a generic and a registered handler,
        # the registered handler takes precedence.
        # Handlers are queried lazily, in order, stopping at the first match.
        candidates = (
            generic_handler.service(handler_call_details)
            for generic_handler in self._generic_handlers
        )
        return next(
            (found for found in candidates if found is not None), None
        )
class _RPCState(object):
    """Mutable per-RPC bookkeeping.

    Shared between the handler running on the thread pool and the
    completion-queue callbacks; every mutation in this module happens while
    holding ``condition``.
    """

    # Context in which handler lookup and the handler itself are run.
    context: contextvars.Context
    # Guards all fields below and is used to wait for batch completion.
    condition: threading.Condition
    # Tokens of batches started but not yet completed. This is an annotation;
    # the previous ``due = Set[str]`` wrongly assigned a typing alias as a
    # class attribute.
    due: Set[str]
    # Most recently received, deserialized request not yet consumed.
    request: Any
    # One of _OPEN, _CLOSED or _CANCELLED.
    client: str
    # False once initial metadata has been sent.
    initial_metadata_allowed: bool
    compression_algorithm: Optional[grpc.Compression]
    # Applies only to the next outgoing message; reset after each send.
    disable_next_compression: bool
    trailing_metadata: Optional[MetadataType]
    # Application-set status code/details; None means "not set".
    code: Optional[grpc.StatusCode]
    details: Optional[bytes]
    # True once a batch sending the final status has been started.
    statused: bool
    # RpcErrors raised by the framework (see _raise_rpc_error), so they are not
    # reported as application failures.
    rpc_errors: List[Exception]
    # Termination callbacks; set to None once handed off by
    # _possibly_finish_call, after which add_callback returns False.
    callbacks: Optional[List[NullaryCallbackType]]
    # Set by _Context.abort.
    aborted: bool

    def __init__(self):
        self.context = contextvars.Context()
        self.condition = threading.Condition()
        self.due = set()
        self.request = None
        self.client = _OPEN
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.trailing_metadata = None
        self.code = None
        self.details = None
        self.statused = False
        self.rpc_errors = []
        self.callbacks = []
        self.aborted = False
def _raise_rpc_error(state: _RPCState) -> None:
    """Raise a new ``grpc.RpcError`` after recording it on *state*.

    Recording it in ``state.rpc_errors`` lets the exception handlers that
    check that list recognize the error as framework-raised.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
def _possibly_finish_call(
    state: _RPCState, token: str
) -> ServerTagCallbackType:
    """Mark *token*'s batch as done and report whether the RPC has finished.

    Every caller in this module holds ``state.condition``. Raises ``KeyError``
    if *token* is not outstanding.

    Returns:
      ``(state, callbacks)`` the first time the RPC is inactive with no
      outstanding batches (``state.callbacks`` is then cleared to ``None``);
      otherwise ``(None, ())``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    pending_callbacks, state.callbacks = state.callbacks, None
    return state, pending_callbacks
def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion callback for a batch carrying the final status."""

    def send_status_from_server(unused_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return send_status_from_server
def _get_initial_metadata(
    state: _RPCState, metadata: Optional[MetadataType]
) -> Optional[MetadataType]:
    """Return the initial metadata to send for this RPC.

    When ``state.compression_algorithm`` is truthy, a compression entry is
    prepended and a tuple is returned. Otherwise *metadata* is returned
    unchanged (possibly ``None``).
    """
    with state.condition:
        algorithm = state.compression_algorithm
        if not algorithm:
            return metadata
        prefix = (_compression.compression_algorithm_to_metadata(algorithm),)
        return prefix if metadata is None else prefix + tuple(metadata)
def _get_initial_metadata_operation(
    state: _RPCState, metadata: Optional[MetadataType]
) -> cygrpc.Operation:
    """Wrap the effective initial metadata in a send-initial-metadata op."""
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS
    )
def _abort(
    state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
) -> None:
    """Start the batch that ends a failing RPC with a status.

    Every caller in this module holds ``state.condition``. Application-set
    ``state.code``/``state.details`` override *code*/*details*. Initial
    metadata is bundled in if it has not been sent yet.
    """
    if state.client is _CANCELLED:
        # Client already cancelled: no status batch is started.
        return
    status_code = _abortion_code(state, code)
    status_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        operations = (_get_initial_metadata_operation(state, None),)
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        operations = ()
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    operations += (
        cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata,
            status_code,
            status_details,
            _EMPTY_FLAGS,
        ),
    )
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
    """Build the callback for the receive-close-on-server batch.

    Marks the client as cancelled or closed, wakes waiters, and retires the
    receive-close token.
    """

    def receive_close_on_server(receive_close_on_server_event):
        with state.condition:
            close_operation = receive_close_on_server_event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return receive_close_on_server
def _receive_message(
    state: _RPCState,
    call: cygrpc.Call,
    request_deserializer: Optional[DeserializingFunction],
) -> ServerCallbackTag:
    """Build the callback for a receive-message batch.

    On a message, deserializes it into ``state.request``; a failed
    deserialization aborts the RPC with INTERNAL. With no message, the client
    side is marked closed. In every case waiters are notified.
    """

    def receive_message(receive_message_event):
        payload = _serialized_request(receive_message_event)
        if payload is None:
            # No message: treat the client side as closed.
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
        # Deserialize outside the lock; only state updates need it.
        request = _common.deserialize(payload, request_deserializer)
        with state.condition:
            if request is None:
                _abort(
                    state,
                    call,
                    cygrpc.StatusCode.internal,
                    b"Exception deserializing request!",
                )
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message
def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
    """Build the callback for a standalone send-initial-metadata batch."""

    def send_initial_metadata(unused_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return send_initial_metadata
def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the callback for a send-message batch identified by *token*.

    Notifies waiters; ``_send_response`` blocks until *token* leaves
    ``state.due``.
    """

    def send_message(unused_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return send_message
class _Context(grpc.ServicerContext):
    """Servicer context handed to application handlers for one RPC.

    Backed by the RPC's ``_RPCState``. Methods that mutate that state do so
    while holding ``state.condition``.
    """

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        rpc_event: cygrpc.BaseEvent,
        state: _RPCState,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        """Return True unless the client cancelled or a status was sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        """Return seconds left until the call deadline, floored at zero."""
        # NOTE(review): assumes the deadline is on the time.time() clock.
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        """Cancel the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """Queue *callback* with the RPC's termination callbacks.

        Returns False if the callbacks were already handed off (the RPC has
        finished and ``state.callbacks`` is ``None``).
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            self._state.callbacks.append(callback)
            return True

    def disable_next_message_compression(self) -> None:
        """Send the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return None if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        """Return the call's auth context with keys decoded to ``str``."""
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send initial metadata immediately.

        Raises:
          grpc.RpcError: if the client has cancelled.
          ValueError: if initial metadata was already sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            elif self._state.initial_metadata_allowed:
                operation = _get_initial_metadata_operation(
                    self._state, initial_metadata
                )
                self._rpc_event.call.start_server_batch(
                    (operation,), _send_initial_metadata(self._state)
                )
                self._state.initial_metadata_allowed = False
                self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
            else:
                error_msg = "Initial metadata no longer allowed!"
                raise ValueError(error_msg)

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record a failure status and raise to unwind the handler.

        The raised exception is caught by the framework, which sees
        ``state.aborted`` and sends the recorded status. ``StatusCode.OK`` is
        rejected: it is logged and replaced with UNKNOWN and empty details.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                "abort() called with StatusCode.OK; returning UNKNOWN"
            )
            code = grpc.StatusCode.UNKNOWN
            details = ""
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Abort using the code, details and trailing metadata of *status*."""
        # Write under the lock like every other state mutation.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        with self._state.condition:
            self._state.code = code

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the application-set status code, or ``None`` if unset."""
        return self._state.code

    def set_details(self, details: str) -> None:
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> Optional[bytes]:
        """Return the encoded status details, or ``None`` if unset."""
        return self._state.details

    def _finalize_state(self) -> None:
        # Intentionally a no-op in this implementation.
        pass
class _RequestIterator(object):
    """Blocking iterator over the client's stream of request messages.

    Each step starts a receive-message batch, then waits on
    ``state.condition`` until a request arrives or the stream ends.
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        state: _RPCState,
        call: cygrpc.Call,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        # Called with self._state.condition held.
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif _is_rpc_state_active(state):
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, self._call, self._request_deserializer),
            )
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
        else:
            raise StopIteration()

    def _look_for_request(self) -> Any:
        # Returns the buffered request (or None if none yet), consuming it.
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif state.request is None and _RECEIVE_MESSAGE_TOKEN not in state.due:
            # Nothing buffered and no receive pending: the stream is over.
            raise StopIteration()
        else:
            request, state.request = state.request, None
            return request

        raise AssertionError()  # should never run

    def _next(self) -> Any:
        with self._state.condition:
            self._raise_or_start_receive_message()
            request = None
            while request is None:
                self._state.condition.wait()
                request = self._look_for_request()
            return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    def next(self) -> Any:
        return self._next()
def _unary_request(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    request_deserializer: Optional[DeserializingFunction],
) -> Callable[[], Any]:
    """Build a thunk that receives exactly one request message.

    The thunk returns the request, or ``None`` in three cases:
      - the RPC is inactive;
      - the client cancelled;
      - the client closed without sending a message, in which case the RPC
        is aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer),
            )
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                state.condition.wait()
                received = state.request
                if received is not None:
                    state.request = None
                    return received
                if state.client is _CLOSED:
                    details = '"{}" requires exactly one request message.'.format(
                        rpc_event.call_details.method
                    )
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unimplemented,
                        _common.encode(details),
                    )
                    return None
                if state.client is _CANCELLED:
                    return None

    return unary_request
def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None,
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application *behavior* inside a servicer context.

    *send_response_callback*, when given, is passed as a third argument.

    Returns:
      ``(result, True)`` if the behavior returned normally; ``(None, False)``
      if it raised. After a raise the RPC is aborted, unless the error was a
      framework-raised ``grpc.RpcError`` (``_abort`` itself does nothing for a
      cancelled client).
    """
    from grpc import _create_servicer_context  # pytype: disable=pyi-error

    with _create_servicer_context(
        rpc_event, state, request_deserializer
    ) as context:
        try:
            if send_response_callback is None:
                response_or_iterator = behavior(argument, context)
            else:
                response_or_iterator = behavior(
                    argument, context, send_response_callback
                )
            return response_or_iterator, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # context.abort() recorded code/details; send them.
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        b"RPC Aborted",
                    )
                elif exception not in state.rpc_errors:
                    try:
                        details = "Exception calling application: {}".format(
                            exception
                        )
                    except Exception:  # pylint: disable=broad-except
                        # str(exception) itself raised.
                        details = (
                            "Calling application raised unprintable Exception!"
                        )
                    # Still inside the outer ``except``: logging.exception
                    # attaches the application traceback, so one report
                    # suffices (previously it was also logged via
                    # format_exception and printed to stderr).
                    _LOGGER.exception(details)
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        _common.encode(details),
                    )
            return None, False
def _take_response_from_response_iterator(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response_iterator: Iterator[ResponseType],
) -> Tuple[ResponseType, bool]:
    """Pull the next response from *response_iterator*.

    Returns:
      - ``(response, True)`` for a response;
      - ``(None, True)`` when the iterator is exhausted;
      - ``(None, False)`` when it raised, after aborting the RPC where
        appropriate.
    """
    try:
        response = next(response_iterator)
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    b"RPC Aborted",
                )
            elif exception not in state.rpc_errors:
                details = "Exception iterating responses: {}".format(exception)
                _LOGGER.exception(details)
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    _common.encode(details),
                )
        return None, False
    else:
        return response, True
def _serialize_response(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response: Any,
    response_serializer: Optional[SerializingFunction],
) -> Optional[bytes]:
    """Serialize *response*; on failure, abort with INTERNAL and return None."""
    serialized = _common.serialize(response, response_serializer)
    if serialized is not None:
        return serialized
    with state.condition:
        _abort(
            state,
            rpc_event.call,
            cygrpc.StatusCode.internal,
            b"Failed to serialize response!",
        )
        return None
def _get_send_message_op_flags_from_state(
    state: _RPCState,
) -> Union[int, cygrpc.WriteFlag]:
    """Return the write flags for the next outgoing message."""
    return (
        cygrpc.WriteFlag.no_compress
        if state.disable_next_compression
        else _EMPTY_FLAGS
    )
def _reset_per_message_state(state: _RPCState) -> None:
    """Clear settings that apply only to a single outgoing message."""
    lock = state.condition
    with lock:
        state.disable_next_compression = False
def _send_response(
    rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
) -> bool:
    """Send one serialized response and wait for its batch to complete.

    Initial metadata is bundled in if not yet sent. Returns False immediately
    if the RPC is inactive. Otherwise returns whether the RPC is still active
    once the batch completes.
    """

    def message_operation():
        return cygrpc.SendMessageOperation(
            serialized_response,
            _get_send_message_op_flags_from_state(state),
        )

    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        if state.initial_metadata_allowed:
            operations = (
                _get_initial_metadata_operation(state, None),
                message_operation(),
            )
            state.initial_metadata_allowed = False
            token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
        else:
            operations = (message_operation(),)
            token = _SEND_MESSAGE_TOKEN
        rpc_event.call.start_server_batch(
            operations, _send_message(state, token)
        )
        state.due.add(token)
        _reset_per_message_state(state)
        # The send-message callback removes the token and notifies.
        while token in state.due:
            state.condition.wait()
        return _is_rpc_state_active(state)
def _status(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    serialized_response: Optional[bytes],
) -> None:
    """Start the final batch of a successfully handled RPC.

    The batch carries the status, plus initial metadata if not yet sent and
    *serialized_response* if given. Nothing is sent if the client cancelled.
    """
    with state.condition:
        if state.client is _CANCELLED:
            return
        status_code = _completion_code(state)
        status_details = _details(state)
        operations = [
            cygrpc.SendStatusFromServerOperation(
                state.trailing_metadata,
                status_code,
                status_details,
                _EMPTY_FLAGS,
            ),
        ]
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            operations.append(
                cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state),
                )
            )
        rpc_event.call.start_server_batch(
            operations,
            _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
        )
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
def _unary_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool entry point for RPCs that produce a single response.

    Obtains the request argument from *argument_thunk*, calls *behavior*,
    serializes its response and sends it together with the final status. Any
    failure at a step skips the rest (the failing helper has already aborted
    the RPC where appropriate).

    Args:
      request_deserializer: Deserializer for request messages. It was
        previously mis-annotated as a serializing function.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    try:
        argument = argument_thunk()
        if argument is not None:
            response, proceed = _call_behavior(
                rpc_event, state, behavior, argument, request_deserializer
            )
            if proceed:
                serialized_response = _serialize_response(
                    rpc_event, state, response, response_serializer
                )
                if serialized_response is not None:
                    _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        # Printed rather than propagated into the pool's future.
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
def _stream_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool entry point for RPCs that stream responses.

    Non-blocking behaviors (``experimental_non_blocking``) receive a
    ``send_response`` callback. Otherwise the behavior's returned iterator is
    drained into that callback.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        # ``None`` means the stream is done: send the final status.
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer
        )
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        if getattr(behavior, "experimental_non_blocking", False):
            _call_behavior(
                rpc_event,
                state,
                behavior,
                argument,
                request_deserializer,
                send_response_callback=send_response,
            )
            return
        response_iterator, proceed = _call_behavior(
            rpc_event, state, behavior, argument, request_deserializer
        )
        if proceed:
            _send_message_callback_to_blocking_iterator_adapter(
                rpc_event, state, send_response, response_iterator
            )
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
def _is_rpc_state_active(state: _RPCState) -> bool:
    """Return True while the client has not cancelled and no status was sent."""
    return not (state.client is _CANCELLED or state.statused)
def _send_message_callback_to_blocking_iterator_adapter(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    send_response_callback: Callable[[ResponseType], None],
    response_iterator: Iterator[ResponseType],
) -> None:
    """Drain *response_iterator* into *send_response_callback*.

    The end of the iterator is forwarded as ``None``. Stops when the iterator
    fails or the RPC becomes inactive.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator
        )
        if not proceed:
            break
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            break
def _select_thread_pool_for_behavior(
    behavior: ArityAgnosticMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.ThreadPoolExecutor:
    """Return the behavior's own ``experimental_thread_pool`` if it is a
    ``ThreadPoolExecutor``, else *default_thread_pool*."""
    custom_pool = getattr(behavior, "experimental_thread_pool", None)
    if isinstance(custom_pool, futures.ThreadPoolExecutor):
        return custom_pool
    return default_thread_pool
def _handle_unary_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-request / unary-response RPC inside ``state.context``."""
    request_deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    behavior = method_handler.unary_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        request_deserializer,
        method_handler.response_serializer,
    )
def _handle_unary_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-request / streaming-response RPC inside ``state.context``."""
    request_deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    behavior = method_handler.unary_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        request_deserializer,
        method_handler.response_serializer,
    )
def _handle_stream_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a streaming-request / unary-response RPC inside ``state.context``."""
    request_deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(
        state, rpc_event.call, request_deserializer
    )
    behavior = method_handler.stream_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        lambda: request_iterator,
        request_deserializer,
        method_handler.response_serializer,
    )
def _handle_stream_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a streaming-request / streaming-response RPC inside ``state.context``."""
    request_deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(
        state, rpc_event.call, request_deserializer
    )
    behavior = method_handler.stream_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        lambda: request_iterator,
        request_deserializer,
        method_handler.response_serializer,
    )
def _find_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_with_handler: _Method,
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
) -> Optional[grpc.RpcMethodHandler]:
    """Resolve the handler for *rpc_event*, through interceptors if configured.

    The lookup runs inside ``state.context``. When *method_with_handler* has no
    name, the decoded method from the call details is used instead.
    """

    def query_handlers(
        handler_call_details: _HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        return method_with_handler.handler(handler_call_details)

    method_name = method_with_handler.name() or _common.decode(
        rpc_event.call_details.method
    )
    handler_call_details = _HandlerCallDetails(
        method_name,
        rpc_event.invocation_metadata,
    )

    if interceptor_pipeline is None:
        return state.context.run(query_handlers, handler_call_details)
    return state.context.run(
        interceptor_pipeline.execute, query_handlers, handler_call_details
    )
def _reject_rpc(
    rpc_event: cygrpc.BaseEvent,
    rpc_state: _RPCState,
    status: cygrpc.StatusCode,
    details: bytes,
):
    """Fail a call with *status*/*details* without running any handler.

    A single batch sends initial metadata and the status and receives the
    client's close. Its callback reports *rpc_state* with no termination
    callbacks.
    """
    initial_metadata = _get_initial_metadata_operation(rpc_state, None)
    receive_close = cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS)
    send_status = cygrpc.SendStatusFromServerOperation(
        None, status, details, _EMPTY_FLAGS
    )

    def on_batch_complete(_ignored_event):
        return rpc_state, ()

    rpc_event.call.start_server_batch(
        (initial_metadata, receive_close, send_status),
        on_batch_complete,
    )
def _handle_with_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Watch for the client's close, then dispatch by streaming arity.

    Both steps happen while ``state.condition`` is held.
    """
    dispatch = {
        (True, True): _handle_stream_stream,
        (True, False): _handle_stream_unary,
        (False, True): _handle_unary_stream,
        (False, False): _handle_unary_unary,
    }
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state),
        )
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        arity = (
            bool(method_handler.request_streaming),
            bool(method_handler.response_streaming),
        )
        return dispatch[arity](rpc_event, state, method_handler, thread_pool)
def _handle_call(
    rpc_event: cygrpc.BaseEvent,
    method_with_handler: _Method,
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
    thread_pool: futures.ThreadPoolExecutor,
    concurrency_exceeded: bool,
) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
    """Handle an incoming call event using the provided handlers.

    A registered method carries its name on *method_with_handler*. An
    unregistered method's name is in ``rpc_event.call_details.method``, and
    generic handlers are queried to find the actual handler.

    Returns:
      - ``(None, None)`` if the event failed or names no method;
      - ``(rpc_state, None)`` if the call was rejected (handler lookup
        error, unknown method, or concurrency limit exceeded);
      - ``(rpc_state, future)`` if a handler was scheduled.
    """
    if not rpc_event.success:
        return None, None
    if not (rpc_event.call_details.method or method_with_handler.name()):
        return None, None

    rpc_state = _RPCState()
    try:
        method_handler = _find_method_handler(
            rpc_event,
            rpc_state,
            method_with_handler,
            interceptor_pipeline,
        )
    except Exception as exception:  # pylint: disable=broad-except
        details = "Exception servicing handler: {}".format(exception)
        _LOGGER.exception(details)
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unknown,
            b"Error in service handler!",
        )
        return rpc_state, None

    if method_handler is None:
        rejection = (cygrpc.StatusCode.unimplemented, b"Method not found!")
    elif concurrency_exceeded:
        rejection = (
            cygrpc.StatusCode.resource_exhausted,
            b"Concurrent RPC limit exceeded!",
        )
    else:
        return (
            rpc_state,
            _handle_with_method_handler(
                rpc_event, rpc_state, method_handler, thread_pool
            ),
        )
    _reject_rpc(rpc_event, rpc_state, *rejection)
    return rpc_state, None
@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of a server.

    STOPPED becomes STARTED when the server is started. STARTED becomes GRACE
    when shutdown begins. GRACE becomes STOPPED once no RPC states or pending
    requests remain.
    """

    # Not serving: never started, or fully shut down.
    STOPPED = "stopped"
    # Serving; new call requests are re-posted after each incoming call.
    STARTED = "started"
    # Shutdown requested; waiting for outstanding work to drain.
    GRACE = "grace"
class _ServerState(object):
    """Mutable state shared by a server's public API and its serving thread.

    ``lock`` is the re-entrant lock the module-level helpers take before
    mutating these fields. ``server_deallocated`` is the exception: it is
    written without the lock from ``_Server.__del__``.
    """

    lock: threading.RLock
    completion_queue: cygrpc.CompletionQueue
    server: cygrpc.Server
    generic_handlers: List[grpc.GenericRpcHandler]
    registered_method_handlers: Dict[str, grpc.RpcMethodHandler]
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
    thread_pool: futures.ThreadPoolExecutor
    stage: _ServerStage
    termination_event: threading.Event
    shutdown_events: List[threading.Event]
    maximum_concurrent_rpcs: Optional[int]
    active_rpc_count: int
    rpc_states: Set[_RPCState]
    due: Set[str]
    server_deallocated: bool

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        completion_queue: cygrpc.CompletionQueue,
        server: cygrpc.Server,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptor_pipeline: Optional[_interceptor._ServicePipeline],
        thread_pool: futures.ThreadPoolExecutor,
        maximum_concurrent_rpcs: Optional[int],
    ):
        # Core objects and dispatch configuration supplied by the caller.
        self.completion_queue = completion_queue
        self.server = server
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool

        # Handler registries. The generic handlers are copied into a list so
        # that later additions extend this server's own list.
        self.generic_handlers = list(generic_handlers)
        self.registered_method_handlers = {}

        # Lifecycle. termination_event is always the first shutdown event, so
        # it is set whenever the shutdown events are set.
        self.lock = threading.RLock()
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]

        # Concurrency accounting.
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False
def _add_generic_handlers(
    state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler]
) -> None:
    """Append ``generic_handlers`` to the server's generic handler list.

    The existing list object is extended in place while ``state.lock`` is held.
    """
    with state.lock:
        handler_list = state.generic_handlers
        handler_list.extend(generic_handlers)
def _add_registered_method_handlers(
    state: _ServerState, method_handlers: Dict[str, grpc.RpcMethodHandler]
) -> None:
    """Merge ``method_handlers`` into the server's registered-method table.

    Entries whose names already exist are overwritten. The table is updated
    while ``state.lock`` is held.
    """
    with state.lock:
        handler_table = state.registered_method_handlers
        handler_table.update(method_handlers)
def _add_insecure_port(state: _ServerState, address: bytes) -> int:
    """Bind ``address`` on the Core server without transport credentials.

    Returns the value reported by ``add_http2_port`` (an int per the
    annotation). The public caller validates it as a port-binding result.
    """
    with state.lock:
        bound_port = state.server.add_http2_port(address)
    return bound_port
def _add_secure_port(
    state: _ServerState,
    address: bytes,
    server_credentials: grpc.ServerCredentials,
) -> int:
    """Bind ``address`` on the Core server using ``server_credentials``.

    The Core-level credentials object (``server_credentials._credentials``)
    is what gets passed down. Returns the value reported by
    ``add_http2_port`` (an int per the annotation).
    """
    with state.lock:
        add_port = state.server.add_http2_port
        bound_port = add_port(address, server_credentials._credentials)
    return bound_port
def _request_call(state: _ServerState) -> None:
    """Ask Core for the next incoming call to a non-registered method.

    The request is tagged with ``_REQUEST_CALL_TAG``. The tag stays recorded
    in ``state.due`` until the serving loop handles the matching event.
    """
    completion_queue = state.completion_queue
    state.server.request_call(
        completion_queue, completion_queue, _REQUEST_CALL_TAG
    )
    state.due.add(_REQUEST_CALL_TAG)
def _request_registered_call(state: _ServerState, method: str) -> None:
    """Ask Core for the next incoming call to the registered ``method``.

    The method name doubles as the request tag, which is how the serving loop
    recognizes the resulting event. The tag stays recorded in ``state.due``
    until that event is handled.
    """
    completion_queue = state.completion_queue
    state.server.request_registered_call(
        completion_queue, completion_queue, method, method
    )
    state.due.add(method)
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state: _ServerState) -> bool:
    """Finish shutting down, but only if nothing is still outstanding.

    When ``state.rpc_states`` and ``state.due`` are both empty, this destroys
    the Core server, sets every registered shutdown event and moves the stage
    to STOPPED.

    Returns:
      True if this call stopped the server. False if RPC states or pending
      requests remain, in which case nothing is changed.
    """
    if state.rpc_states or state.due:
        return False
    state.server.destroy()
    for event in state.shutdown_events:
        event.set()
    state.stage = _ServerStage.STOPPED
    return True
def _on_call_completed(state: _ServerState) -> None:
    """Give back one slot of the active-RPC count, under ``state.lock``.

    Attached as a done-callback to each RPC future that
    ``_process_event_and_continue`` counted as active.
    """
    with state.lock:
        state.active_rpc_count = state.active_rpc_count - 1
# pylint: disable=too-many-branches
def _process_event_and_continue(
    state: _ServerState, event: cygrpc.BaseEvent
) -> bool:
    """Handle one completion-queue event on behalf of the serving loop.

    ``event.tag`` identifies what completed:

    * ``_SHUTDOWN_TAG``: the Core shutdown started by ``_begin_shutdown_once``
      has finished.
    * ``_REQUEST_CALL_TAG`` or a registered method name: a new incoming call,
      for a non-registered or a registered method respectively.
    * anything else: a callable that is invoked with the event and returns
      ``(rpc_state, callbacks)``.

    Args:
      state: The server state shared with the serving thread.
      event: An event returned by ``state.completion_queue.poll``.

    Returns:
      False once ``_stop_serving`` reports that the server has stopped (the
      serving loop should exit), True otherwise.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif (
        event.tag is _REQUEST_CALL_TAG
        or event.tag in state.registered_method_handlers.keys()
    ):
        # Registered-method requests are tagged with the method name itself
        # (see _request_registered_call). Non-registered requests use
        # _REQUEST_CALL_TAG and leave registered_method_name as None.
        registered_method_name = None
        if event.tag in state.registered_method_handlers.keys():
            registered_method_name = event.tag
            method_with_handler = _RegisteredMethod(
                registered_method_name,
                state.registered_method_handlers.get(
                    registered_method_name, None
                ),
            )
        else:
            method_with_handler = _GenericMethod(
                state.generic_handlers,
            )
        with state.lock:
            # This outstanding request has now been answered.
            state.due.remove(event.tag)
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None
                and state.active_rpc_count >= state.maximum_concurrent_rpcs
            )
            rpc_state, rpc_future = _handle_call(
                event,
                method_with_handler,
                state.interceptor_pipeline,
                state.thread_pool,
                concurrency_exceeded,
            )
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            # Only calls for which _handle_call returned a future count toward
            # the concurrency limit; the done-callback decrements the count.
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda _unused_future: _on_call_completed(state)
                )
            # While serving, post a replacement request of the same kind so the
            # next call can be accepted. None is never a registered key, so
            # non-registered requests take the _request_call branch.
            if state.stage is _ServerStage.STARTED:
                if (
                    registered_method_name
                    in state.registered_method_handlers.keys()
                ):
                    _request_registered_call(state, registered_method_name)
                else:
                    _request_call(state)
            # Outside STARTED no replacement request is posted; stop serving
            # if nothing remains outstanding.
            elif _stop_serving(state):
                should_continue = False
    else:
        # Any other tag is itself the handler for this event.
        rpc_state, callbacks = event.tag(event)
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Exception calling callback!")
        # A non-None rpc_state is dropped from the tracked set (presumably
        # because that RPC has finished -- TODO confirm against the tag
        # implementations), which may allow shutdown to complete.
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
def _serve(state: _ServerState) -> None:
    """Run the completion-queue loop for ``state`` until serving stops.

    Each poll uses a deadline ``_DEALLOCATED_SERVER_CHECK_PERIOD_S`` seconds
    ahead. The loop therefore regularly rechecks ``state.server_deallocated``
    (set by ``_Server.__del__``) and begins shutdown once it is set, even when
    no events arrive.
    """
    while True:
        deadline = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(deadline)
        if state.server_deallocated:
            _begin_shutdown_once(state)
        received = event.completion_type != cygrpc.CompletionType.queue_timeout
        if received and not _process_event_and_continue(state, event):
            return
        # Drop our reference to the event before polling again: if it refers
        # to a shut-down Call object, keeping it alive can induce spinlock.
        event = None
def _begin_shutdown_once(state: _ServerState) -> None:
    """Request a Core shutdown, but only from the STARTED stage.

    Moves the stage to GRACE and records ``_SHUTDOWN_TAG`` in ``state.due`` so
    the serving loop can recognize the shutdown completion. In any other stage
    this does nothing, so repeated calls are harmless.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)
def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Begin stopping the server and return an event set once it has stopped.

    If the server is already STOPPED, an already-set event is returned.
    Otherwise a Core shutdown is begun (if not already in progress) and a new
    event is appended to ``state.shutdown_events``.

    With ``grace`` of None, every in-flight call is cancelled immediately and
    this call blocks until the server has stopped. With a numeric ``grace``, a
    helper thread cancels remaining calls once the server stops or ``grace``
    seconds pass, whichever comes first. In that case the event is returned
    without blocking.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped

        _begin_shutdown_once(state)
        shutdown_event = threading.Event()
        state.shutdown_events.append(shutdown_event)

        if grace is not None:

            def cancel_all_calls_after_grace():
                shutdown_event.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            threading.Thread(target=cancel_all_calls_after_grace).start()
            return shutdown_event

        state.server.cancel_all_calls()

    # Wait outside the lock: the serving thread takes state.lock before it
    # calls _stop_serving, which is what sets the shutdown events.
    shutdown_event.wait()
    return shutdown_event
def _start(state: _ServerState) -> None:
    """Start Core, post the initial call requests and spawn the serving thread.

    Everything, including starting the daemon serving thread, happens while
    ``state.lock`` is held.

    Raises:
      ValueError: If the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError("Cannot start already-started server!")
        state.server.start()
        state.stage = _ServerStage.STARTED
        # One outstanding request per registered method, keyed by its name...
        for registered_method in state.registered_method_handlers:
            _request_registered_call(state, registered_method)
        # ...plus one request covering every non-registered method.
        _request_call(state)
        serving_thread = threading.Thread(
            target=_serve, args=(state,), daemon=True
        )
        serving_thread.start()
def _validate_generic_rpc_handlers(
    generic_rpc_handlers: Iterable[grpc.GenericRpcHandler],
) -> None:
    """Check that every handler exposes a ``service`` attribute.

    This is a duck-typing check only. It does not verify that ``service`` is
    callable, or that the handler subclasses ``grpc.GenericRpcHandler``.

    Args:
      generic_rpc_handlers: Handlers to check. The iterable is consumed, so
        callers that use it again afterwards must pass a re-iterable
        collection.

    Raises:
      AttributeError: If a handler has no ``service`` attribute, or it is None.
    """
    for generic_rpc_handler in generic_rpc_handlers:
        if getattr(generic_rpc_handler, "service", None) is None:
            # The trailing space is needed: the two literals are concatenated.
            error_msg = (
                f'"{generic_rpc_handler}" must conform to '
                'grpc.GenericRpcHandler type but does not have "service" method!'
            )
            raise AttributeError(error_msg)
def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> Sequence[ChannelArgumentType]:
    """Extend ``base_options`` with compression and call-tracer options.

    Returns a new tuple with the caller's options first, then the channel
    option derived from ``compression``, then whatever
    ``create_server_call_tracer_factory_option`` yields for ``xds``.
    """
    compression_options = _compression.create_channel_option(compression)
    tracer_options = (
        _observability.create_server_call_tracer_factory_option(xds)
    )
    augmented = tuple(base_options)
    augmented += compression_options
    augmented += tracer_options
    return augmented
class _Server(grpc.Server):
    """``grpc.Server`` implementation backed by a Cython Core server.

    All mutable server state lives in ``self._state``. The public methods
    delegate to module-level helpers that operate on it.
    """

    _state: _ServerState

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        thread_pool: futures.ThreadPoolExecutor,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptors: Sequence[grpc.ServerInterceptor],
        options: Sequence[ChannelArgumentType],
        maximum_concurrent_rpcs: Optional[int],
        compression: Optional[grpc.Compression],
        xds: bool,
    ):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression, xds), xds)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(
            completion_queue,
            server,
            generic_handlers,
            _interceptor.service_pipeline(interceptors),
            thread_pool,
            maximum_concurrent_rpcs,
        )
        self._cy_server = server

    def add_generic_rpc_handlers(
        self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
    ) -> None:
        """Validate and append generic handlers for non-registered methods.

        Raises:
          AttributeError: If a handler has no ``service`` attribute.
        """
        # Materialize once. Validation iterates the handlers, so a one-shot
        # iterator (e.g. a generator) would otherwise be exhausted and nothing
        # would be stored.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_registered_method_handlers(
        self,
        service_name: str,
        method_handlers: Dict[str, grpc.RpcMethodHandler],
    ) -> None:
        """Register ``method_handlers`` under ``service_name`` with Core.

        Each method name is fully qualified with ``service_name`` and
        registered on the Core server, so that once the server starts a call
        request is posted per method. Once the server has started, this is a
        silent no-op.
        """
        # Hold the lock for the whole registration so _start() (which takes
        # the same lock) cannot run between the stage check and the Core
        # registration below. The lock is an RLock, so the nested acquisition
        # in _add_registered_method_handlers is safe.
        with self._state.lock:
            # Can't register method once server started.
            if self._state.stage is _ServerStage.STARTED:
                return

            # TODO(xuanwn): We should validate method_handlers first.
            method_to_handlers = {
                _common.fully_qualified_method(service_name, method): handler
                for method, handler in method_handlers.items()
            }
            for fully_qualified_method in method_to_handlers:
                self._cy_server.register_method(fully_qualified_method)
            _add_registered_method_handlers(self._state, method_to_handlers)

    def add_insecure_port(self, address: str) -> int:
        """Bind ``address`` without credentials and validate the result."""
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address))
        )

    def add_secure_port(
        self, address: str, server_credentials: grpc.ServerCredentials
    ) -> int:
        """Bind ``address`` with ``server_credentials`` and validate the result."""
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(
                self._state, _common.encode(address), server_credentials
            ),
        )

    def start(self) -> None:
        """Start serving. Raises ValueError if the server is not stopped."""
        _start(self._state)

    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
        """Block until the server terminates or ``timeout`` elapses.

        Returns the bool produced by ``_common.wait`` (see
        ``grpc.Server.wait_for_termination`` for its meaning).
        """
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(
            self._state.termination_event.wait,
            self._state.termination_event.is_set,
            timeout=timeout,
        )

    def stop(self, grace: Optional[float]) -> threading.Event:
        """Stop the server; returns an event that is set once it has stopped."""
        return _stop(self._state, grace)

    def __del__(self):
        if hasattr(self, "_state"):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True
def create_server(
    thread_pool: futures.ThreadPoolExecutor,
    generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
    interceptors: Sequence[grpc.ServerInterceptor],
    options: Sequence[ChannelArgumentType],
    maximum_concurrent_rpcs: Optional[int],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> _Server:
    """Validate the generic handlers and construct a new, unstarted server.

    Args:
      thread_pool: Executor handed to the server for running calls.
      generic_rpc_handlers: Handlers for non-registered methods. Any iterable
        is accepted. It is materialized once, so validation cannot exhaust a
        one-shot iterator before the server stores the handlers.
      interceptors: Server interceptors, combined into a service pipeline.
      options: Channel arguments for the underlying Core server.
      maximum_concurrent_rpcs: Cap on concurrently active RPCs, or None for
        no cap.
      compression: Compression setting turned into a channel option, or None.
      xds: Forwarded to the Core server and to option augmentation.

    Returns:
      A ``_Server`` in the STOPPED stage.

    Raises:
      AttributeError: If a handler has no ``service`` attribute.
    """
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(
        thread_pool,
        generic_rpc_handlers,
        interceptors,
        options,
        maximum_concurrent_rpcs,
        compression,
        xds,
    )
1536 )