Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/grpc/_server.py: 36%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Service-side implementation of gRPC Python."""
16from __future__ import annotations
18import abc
19import collections
20from concurrent import futures
21import contextvars
22import enum
23import logging
24import threading
25import time
26import traceback
27from typing import (
28 Any,
29 Callable,
30 Dict,
31 Iterable,
32 Iterator,
33 List,
34 Mapping,
35 Optional,
36 Sequence,
37 Set,
38 Tuple,
39 Union,
40)
42import grpc # pytype: disable=pyi-error
43from grpc import _common # pytype: disable=pyi-error
44from grpc import _compression # pytype: disable=pyi-error
45from grpc import _interceptor # pytype: disable=pyi-error
46from grpc import _observability # pytype: disable=pyi-error
47from grpc._cython import cygrpc
48from grpc._typing import ArityAgnosticMethodHandler
49from grpc._typing import ChannelArgumentType
50from grpc._typing import DeserializingFunction
51from grpc._typing import MetadataType
52from grpc._typing import NullaryCallbackType
53from grpc._typing import ResponseType
54from grpc._typing import SerializingFunction
55from grpc._typing import ServerCallbackTag
56from grpc._typing import ServerTagCallbackType
_LOGGER = logging.getLogger(__name__)

# Server-level tags; they are used outside this section of the file —
# presumably to label shutdown and request-call completion events (confirm).
_SHUTDOWN_TAG = "shutdown"
_REQUEST_CALL_TAG = "request_call"

# Tokens naming an in-flight batch. Each is added to _RPCState.due when its
# batch is started and removed by the batch's completion callback via
# _possibly_finish_call; the RPC is finished once nothing is due and the RPC
# is no longer active.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server"
_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata"
_RECEIVE_MESSAGE_TOKEN = "receive_message"
_SEND_MESSAGE_TOKEN = "send_message"
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    "send_initial_metadata * send_message"
)
_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    "send_initial_metadata * send_status_from_server"
)

# Values of _RPCState.client: the client side is open, has closed (no more
# messages will arrive), or the call was cancelled.
_OPEN = "open"
_CLOSED = "closed"
_CANCELLED = "cancelled"

# "No flags" value passed to batch operations.
_EMPTY_FLAGS = 0

# Used outside this section; by name, a polling period in seconds for
# noticing that the server was deallocated — confirm at the use site.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
# Used outside this section; by name, a timeout large enough to behave as
# "no timeout" — confirm at the use site.
_INF_TIMEOUT = 1e9
85def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
86 return request_event.batch_operations[0].message()
def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate an application-level grpc.StatusCode to its cygrpc value.

    Codes absent from the translation table are reported as ``unknown``.
    """
    translated = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if translated is None:
        return cygrpc.StatusCode.unknown
    return translated
def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Return the status code to send when the handler completes normally.

    ``ok`` unless the application set a code on *state*, in which case that
    code is translated with _application_code.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
101def _abortion_code(
102 state: _RPCState, code: cygrpc.StatusCode
103) -> cygrpc.StatusCode:
104 if state.code is None:
105 return code
106 else:
107 return _application_code(state.code)
110def _details(state: _RPCState) -> bytes:
111 return b"" if state.details is None else state.details
class _HandlerCallDetails(
    collections.namedtuple(
        "_HandlerCallDetails",
        (
            "method",
            "invocation_metadata",
        ),
    ),
    grpc.HandlerCallDetails,
):
    """Immutable ``(method, invocation_metadata)`` pair used for handler lookup.

    The namedtuple supplies the two fields, and grpc.HandlerCallDetails makes
    the object acceptable to generic handlers' ``service()`` and to the
    interceptor pipeline, both of which receive it in this module.
    """

    pass
127class _Method(abc.ABC):
128 @abc.abstractmethod
129 def name(self) -> Optional[str]:
130 raise NotImplementedError()
132 @abc.abstractmethod
133 def handler(
134 self, handler_call_details: _HandlerCallDetails
135 ) -> Optional[grpc.RpcMethodHandler]:
136 raise NotImplementedError()
class _RegisteredMethod(_Method):
    """A method whose handler was registered ahead of time under a fixed name."""

    def __init__(
        self,
        name: str,
        registered_handler: Optional[grpc.RpcMethodHandler],
    ):
        self._name, self._registered_handler = name, registered_handler

    def name(self) -> Optional[str]:
        """Return the name this method was registered under."""
        return self._name

    def handler(
        self, handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the pre-registered handler; the call details are not consulted."""
        del handler_call_details  # unused: the handler was fixed at registration
        return self._registered_handler
class _GenericMethod(_Method):
    """A method resolved at call time by asking generic handlers in order."""

    def __init__(
        self,
        generic_handlers: List[grpc.GenericRpcHandler],
    ):
        self._generic_handlers = generic_handlers

    def name(self) -> Optional[str]:
        """Generic methods have no name known before the call arrives."""
        return None

    def handler(
        self, handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the first non-None handler offered by the generic handlers.

        Handlers are queried lazily in list order and the search stops at the
        first match; None means no generic handler serves the method. (When a
        method has both a generic and a registered handler the registered one
        takes precedence; that choice is not made in this class.)
        """
        offered = (
            generic.service(handler_call_details)
            for generic in self._generic_handlers
        )
        return next((found for found in offered if found is not None), None)
179class _RPCState(object):
180 context: contextvars.Context
181 condition: threading.Condition
182 due = Set[str]
183 request: Any
184 client: str
185 initial_metadata_allowed: bool
186 compression_algorithm: Optional[grpc.Compression]
187 disable_next_compression: bool
188 trailing_metadata: Optional[MetadataType]
189 code: Optional[grpc.StatusCode]
190 details: Optional[bytes]
191 statused: bool
192 rpc_errors: List[Exception]
193 callbacks: Optional[List[NullaryCallbackType]]
194 aborted: bool
196 def __init__(self):
197 self.context = contextvars.Context()
198 self.condition = threading.Condition()
199 self.due = set()
200 self.request = None
201 self.client = _OPEN
202 self.initial_metadata_allowed = True
203 self.compression_algorithm = None
204 self.disable_next_compression = False
205 self.trailing_metadata = None
206 self.code = None
207 self.details = None
208 self.statused = False
209 self.rpc_errors = []
210 self.callbacks = []
211 self.aborted = False
def _raise_rpc_error(state: _RPCState) -> None:
    """Raise a new grpc.RpcError after recording it on *state*.

    Recording the instance in ``state.rpc_errors`` lets the exception
    handlers around application code tell framework-raised errors apart from
    genuine application failures.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
def _possibly_finish_call(
    state: _RPCState, token: str
) -> ServerTagCallbackType:
    """Retire *token* and report whether the RPC is now completely finished.

    Callers in this module hold ``state.condition``. When the RPC is no
    longer active and nothing else is due, the pending callbacks are detached
    (``state.callbacks`` becomes None, so later add_callback() calls return
    False) and ``(state, callbacks)`` is returned; otherwise ``(None, ())``.

    Raises:
        KeyError: if *token* is not currently in ``state.due``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    detached_callbacks = state.callbacks
    state.callbacks = None
    return state, detached_callbacks
232def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
233 def send_status_from_server(unused_send_status_from_server_event):
234 with state.condition:
235 return _possibly_finish_call(state, token)
237 return send_status_from_server
240def _get_initial_metadata(
241 state: _RPCState, metadata: Optional[MetadataType]
242) -> Optional[MetadataType]:
243 with state.condition:
244 if state.compression_algorithm:
245 compression_metadata = (
246 _compression.compression_algorithm_to_metadata(
247 state.compression_algorithm
248 ),
249 )
250 if metadata is None:
251 return compression_metadata
252 else:
253 return compression_metadata + tuple(metadata)
254 else:
255 return metadata
def _get_initial_metadata_operation(
    state: _RPCState, metadata: Optional[MetadataType]
) -> cygrpc.Operation:
    """Wrap the effective initial metadata in a SendInitialMetadataOperation."""
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS
    )
def _abort(
    state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
) -> None:
    """Send a final status for the RPC unless the client already cancelled.

    Callers in this module hold ``state.condition``. Application-set status
    code and details on *state* take precedence over *code* and *details*.
    Initial metadata is sent in the same batch if it has not been sent yet.
    Marks the state as statused and registers the batch's token as due.
    """
    if state.client is _CANCELLED:
        return
    effective_code = _abortion_code(state, code)
    effective_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        leading_operations = (_get_initial_metadata_operation(state, None),)
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        leading_operations = ()
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    status_operation = cygrpc.SendStatusFromServerOperation(
        state.trailing_metadata,
        effective_code,
        effective_details,
        _EMPTY_FLAGS,
    )
    call.start_server_batch(
        leading_operations + (status_operation,),
        _send_status_from_server(state, token),
    )
    state.statused = True
    state.due.add(token)
301def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
302 def receive_close_on_server(receive_close_on_server_event):
303 with state.condition:
304 if receive_close_on_server_event.batch_operations[0].cancelled():
305 state.client = _CANCELLED
306 elif state.client is _OPEN:
307 state.client = _CLOSED
308 state.condition.notify_all()
309 return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
311 return receive_close_on_server
314def _receive_message(
315 state: _RPCState,
316 call: cygrpc.Call,
317 request_deserializer: Optional[DeserializingFunction],
318) -> ServerCallbackTag:
319 def receive_message(receive_message_event):
320 serialized_request = _serialized_request(receive_message_event)
321 if serialized_request is None:
322 with state.condition:
323 if state.client is _OPEN:
324 state.client = _CLOSED
325 state.condition.notify_all()
326 return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
327 else:
328 request = _common.deserialize(
329 serialized_request, request_deserializer
330 )
331 with state.condition:
332 if request is None:
333 _abort(
334 state,
335 call,
336 cygrpc.StatusCode.internal,
337 b"Exception deserializing request!",
338 )
339 else:
340 state.request = request
341 state.condition.notify_all()
342 return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
344 return receive_message
347def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
348 def send_initial_metadata(unused_send_initial_metadata_event):
349 with state.condition:
350 return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
352 return send_initial_metadata
355def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
356 def send_message(unused_send_message_event):
357 with state.condition:
358 state.condition.notify_all()
359 return _possibly_finish_call(state, token)
361 return send_message
class _Context(grpc.ServicerContext):
    """ServicerContext given to application handlers for a single RPC.

    Call-level information (metadata, peer, deadline, auth context) is read
    from the request-call event. Status, metadata and compression choices
    are recorded on the shared _RPCState while holding ``state.condition``.
    """

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        rpc_event: cygrpc.BaseEvent,
        state: _RPCState,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        """Return True until the client cancels or a status has been sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        """Return the seconds left before the call deadline, never negative."""
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        """Cancel the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """Queue *callback* for when the RPC finishes.

        Returns:
            False if the RPC has already finished (its callback list was
            detached), True if the callback was queued.
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            self._state.callbacks.append(callback)
            return True

    def disable_next_message_compression(self) -> None:
        """Send the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        """Return the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        """Return the decoded peer address of the call."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        """Return the peer's identities as reported by cygrpc."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        """Return the decoded peer identity key, or None if there is none."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return None if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        """Return the call's auth context with decoded keys (empty if absent)."""
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        """Select the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send initial metadata now, ahead of any response.

        Raises:
            grpc.RpcError: if the client has cancelled the call.
            ValueError: if initial metadata was already sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata
                    )
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state)
                    )
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError("Initial metadata no longer allowed!")

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        """Set the metadata sent along with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the trailing metadata set so far, if any."""
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Fail the RPC with *code* and *details* by raising an Exception.

        StatusCode.OK is not a valid abort code; it is logged and replaced by
        UNKNOWN with empty details. The code that invoked the handler sees
        ``state.aborted`` and sends the recorded status.

        Raises:
            Exception: always.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                "abort() called with StatusCode.OK; returning UNKNOWN"
            )
            code = grpc.StatusCode.UNKNOWN
            details = ""
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Abort using the code, details and trailing metadata of *status*."""
        # Take the lock like set_trailing_metadata() does; callbacks read
        # trailing_metadata under the same condition.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        """Set the status code to report when the handler completes."""
        with self._state.condition:
            self._state.code = code

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the status code set so far, or None if none was set."""
        return self._state.code

    def set_details(self, details: str) -> None:
        """Set the status details (stored encoded) to report on completion."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> Optional[bytes]:
        """Return the encoded status details set so far, or None."""
        return self._state.details

    def _finalize_state(self) -> None:
        """Finalization hook; this implementation has nothing to release."""
        pass
class _RequestIterator(object):
    """Blocking iterator over the request messages of a streaming-request RPC.

    Each ``next()`` starts a ReceiveMessage batch, then waits on
    ``state.condition`` until a request is available. Iteration ends with
    StopIteration when the stream is over. If the client cancels,
    grpc.RpcError is raised.
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        state: _RPCState,
        call: cygrpc.Call,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        """Start receiving the next message, or raise if no more can arrive.

        Called with ``self._state.condition`` held.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif not _is_rpc_state_active(state):
            raise StopIteration()
        else:
            receive_operation = cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS)
            on_message = _receive_message(
                state, self._call, self._request_deserializer
            )
            self._call.start_server_batch((receive_operation,), on_message)
            state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self) -> Any:
        """Take the pending request, or None if the receive is still outstanding.

        Raises StopIteration once the receive finished with no request, and
        grpc.RpcError if the client cancelled.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif state.request is None and _RECEIVE_MESSAGE_TOKEN not in state.due:
            raise StopIteration()
        else:
            request, state.request = state.request, None
            return request

        raise AssertionError()  # unreachable: every branch returns or raises

    def _next(self) -> Any:
        """Receive and return the next request, blocking until it arrives."""
        with self._state.condition:
            self._raise_or_start_receive_message()
            request = None
            while request is None:
                self._state.condition.wait()
                request = self._look_for_request()
            return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    # Python 2-style alias kept for callers that invoke ``.next()`` directly.
    next = __next__
def _unary_request(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    request_deserializer: Optional[DeserializingFunction],
) -> Callable[[], Any]:
    """Return a thunk that receives the single request of a unary-request RPC.

    The thunk starts a ReceiveMessage batch and blocks on ``state.condition``
    until one of the following happens:

    - A request arrives: it is returned.
    - The RPC becomes inactive: None is returned. This covers a client
      cancel, and a status that was already sent (e.g. _receive_message
      aborted on a deserialization failure). Sending a second status from
      here would start a second, conflicting status batch.
    - The client closes without sending a message: the RPC is aborted with
      UNIMPLEMENTED and None is returned.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer),
            )
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                state.condition.wait()
                if state.request is not None:
                    request = state.request
                    state.request = None
                    return request
                if not _is_rpc_state_active(state):
                    # Cancelled, or a status was already sent for this RPC.
                    return None
                if state.client is _CLOSED:
                    details = '"{}" requires exactly one request message.'.format(
                        rpc_event.call_details.method
                    )
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unimplemented,
                        _common.encode(details),
                    )
                    return None
                # Spurious or unrelated wake-up: keep waiting.

    return unary_request
def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None,
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application handler inside a fresh servicer context.

    Returns ``(result, True)`` when the handler returns normally. If it
    raises, ``(None, False)`` is returned, and:

    - if the handler called abort(), the recorded status is sent;
    - otherwise, unless the exception is an RpcError raised by the framework,
      the failure is logged and the RPC is aborted with UNKNOWN.
    """
    from grpc import _create_servicer_context  # pytype: disable=pyi-error

    with _create_servicer_context(
        rpc_event, state, request_deserializer
    ) as context:
        try:
            if send_response_callback is None:
                result = behavior(argument, context)
            else:
                result = behavior(argument, context, send_response_callback)
            return result, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        b"RPC Aborted",
                    )
                elif exception not in state.rpc_errors:
                    try:
                        details = "Exception calling application: {}".format(
                            exception
                        )
                    except Exception:  # pylint: disable=broad-except
                        # str(exception) itself failed; dump the traceback so
                        # the original error is not lost.
                        details = (
                            "Calling application raised unprintable Exception!"
                        )
                        _LOGGER.exception(
                            traceback.format_exception(
                                type(exception),
                                exception,
                                exception.__traceback__,
                            )
                        )
                        traceback.print_exc()
                    _LOGGER.exception(details)
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        _common.encode(details),
                    )
            return None, False
648def _take_response_from_response_iterator(
649 rpc_event: cygrpc.BaseEvent,
650 state: _RPCState,
651 response_iterator: Iterator[ResponseType],
652) -> Tuple[ResponseType, bool]:
653 try:
654 return next(response_iterator), True
655 except StopIteration:
656 return None, True
657 except Exception as exception: # pylint: disable=broad-except
658 with state.condition:
659 if state.aborted:
660 _abort(
661 state,
662 rpc_event.call,
663 cygrpc.StatusCode.unknown,
664 b"RPC Aborted",
665 )
666 elif exception not in state.rpc_errors:
667 details = "Exception iterating responses: {}".format(exception)
668 _LOGGER.exception(details)
669 _abort(
670 state,
671 rpc_event.call,
672 cygrpc.StatusCode.unknown,
673 _common.encode(details),
674 )
675 return None, False
def _serialize_response(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response: Any,
    response_serializer: Optional[SerializingFunction],
) -> Optional[bytes]:
    """Serialize *response*, aborting the RPC with INTERNAL on failure.

    A None result from _common.serialize is treated as failure, in which
    case None is returned after the abort.
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is not None:
        return serialized_response
    with state.condition:
        _abort(
            state,
            rpc_event.call,
            cygrpc.StatusCode.internal,
            b"Failed to serialize response!",
        )
    return None
def _get_send_message_op_flags_from_state(
    state: _RPCState,
) -> Union[int, cygrpc.WriteFlag]:
    """Return the write flags for the next outgoing message.

    ``no_compress`` if the application disabled compression for the next
    message, otherwise no flags.
    """
    return (
        cygrpc.WriteFlag.no_compress
        if state.disable_next_compression
        else _EMPTY_FLAGS
    )
707def _reset_per_message_state(state: _RPCState) -> None:
708 with state.condition:
709 state.disable_next_compression = False
def _send_response(
    rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
) -> bool:
    """Send one serialized response and block until its batch completes.

    Initial metadata goes out in the same batch if it has not been sent yet.
    Returns False immediately if the RPC is inactive. Otherwise, once the
    send completes, returns whether the RPC is still active.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False

        if state.initial_metadata_allowed:
            metadata_operations = (_get_initial_metadata_operation(state, None),)
            token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            state.initial_metadata_allowed = False
        else:
            metadata_operations = ()
            token = _SEND_MESSAGE_TOKEN
        message_operation = cygrpc.SendMessageOperation(
            serialized_response,
            _get_send_message_op_flags_from_state(state),
        )
        rpc_event.call.start_server_batch(
            metadata_operations + (message_operation,),
            _send_message(state, token),
        )
        state.due.add(token)
        _reset_per_message_state(state)

        # The completion callback needs this lock, so it can only retire the
        # token while we are waiting.
        while True:
            state.condition.wait()
            if token not in state.due:
                return _is_rpc_state_active(state)
def _status(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    serialized_response: Optional[bytes],
) -> None:
    """Send the final status, unless the client already cancelled.

    The batch may also carry initial metadata (if not yet sent) and a final
    serialized response (if given). Marks the state as statused and registers
    the send-status token as due.
    """
    with state.condition:
        if state.client is _CANCELLED:
            return
        operations = [
            cygrpc.SendStatusFromServerOperation(
                state.trailing_metadata,
                _completion_code(state),
                _details(state),
                _EMPTY_FLAGS,
            ),
        ]
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            operations.append(
                cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state),
                )
            )
        rpc_event.call.start_server_batch(
            operations,
            _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
        )
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
def _unary_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool entry point for RPCs that return a single response.

    Obtains the request via *argument_thunk* (None means the RPC should not
    proceed), invokes the handler, then serializes the result and sends it
    with the final status. Unexpected errors are printed rather than
    propagated into the pool. The call context is always uninstalled.

    ``request_deserializer`` turns request bytes into messages, so it is
    typed as a DeserializingFunction, matching _stream_response_in_pool.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        response, proceed = _call_behavior(
            rpc_event, state, behavior, argument, request_deserializer
        )
        if not proceed:
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer
        )
        if serialized_response is not None:
            _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
def _stream_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool entry point for RPCs that stream responses.

    Handlers flagged ``experimental_non_blocking`` get a ``send_response``
    callback and push responses themselves. Other handlers return an
    iterator, which is drained here. Passing None to ``send_response`` sends
    the final status. Unexpected errors are printed, and the call context is
    always uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer
        )
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        if getattr(behavior, "experimental_non_blocking", False):
            _call_behavior(
                rpc_event,
                state,
                behavior,
                argument,
                request_deserializer,
                send_response_callback=send_response,
            )
            return
        response_iterator, proceed = _call_behavior(
            rpc_event, state, behavior, argument, request_deserializer
        )
        if proceed:
            _send_message_callback_to_blocking_iterator_adapter(
                rpc_event, state, send_response, response_iterator
            )
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
def _is_rpc_state_active(state: _RPCState) -> bool:
    """Return True while the client has not cancelled and no status was sent."""
    if state.client is _CANCELLED:
        return False
    return not state.statused
def _send_message_callback_to_blocking_iterator_adapter(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    send_response_callback: Callable[[ResponseType], None],
    response_iterator: Iterator[ResponseType],
) -> None:
    """Drain a blocking response iterator into *send_response_callback*.

    Each produced item is forwarded, and so is a final None when the
    iterator is exhausted. Draining stops when iteration fails or the RPC
    stops being active.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator
        )
        if not proceed:
            break
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            break
879def _select_thread_pool_for_behavior(
880 behavior: ArityAgnosticMethodHandler,
881 default_thread_pool: futures.ThreadPoolExecutor,
882) -> futures.ThreadPoolExecutor:
883 if hasattr(behavior, "experimental_thread_pool") and isinstance(
884 behavior.experimental_thread_pool, futures.ThreadPoolExecutor
885 ):
886 return behavior.experimental_thread_pool
887 else:
888 return default_thread_pool
def _handle_unary_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Submit a unary-request, unary-response RPC to its thread pool.

    The work runs inside ``state.context`` via _unary_response_in_pool.
    """
    request_thunk = _unary_request(
        rpc_event, state, method_handler.request_deserializer
    )
    behavior = method_handler.unary_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        request_thunk,
        method_handler.request_deserializer,
        method_handler.response_serializer,
    )
def _handle_unary_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Submit a unary-request, streaming-response RPC to its thread pool.

    The work runs inside ``state.context`` via _stream_response_in_pool.
    """
    request_thunk = _unary_request(
        rpc_event, state, method_handler.request_deserializer
    )
    behavior = method_handler.unary_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        request_thunk,
        method_handler.request_deserializer,
        method_handler.response_serializer,
    )
def _handle_stream_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Submit a streaming-request, unary-response RPC to its thread pool.

    The handler's argument is a _RequestIterator over the client's messages.
    """
    request_iterator = _RequestIterator(
        state, rpc_event.call, method_handler.request_deserializer
    )

    def request_thunk():
        return request_iterator

    behavior = method_handler.stream_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        request_thunk,
        method_handler.request_deserializer,
        method_handler.response_serializer,
    )
def _handle_stream_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Submit a streaming-request, streaming-response RPC to its thread pool.

    The handler's argument is a _RequestIterator over the client's messages.
    """
    request_iterator = _RequestIterator(
        state, rpc_event.call, method_handler.request_deserializer
    )

    def request_thunk():
        return request_iterator

    behavior = method_handler.stream_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        request_thunk,
        method_handler.request_deserializer,
        method_handler.response_serializer,
    )
def _find_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_with_handler: _Method,
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
) -> Optional[grpc.RpcMethodHandler]:
    """Resolve the handler for an incoming call, honouring interceptors.

    The method name comes from *method_with_handler* when it has one, and is
    otherwise decoded from the call details. The lookup runs inside
    ``state.context``, through ``interceptor_pipeline.execute`` when a
    pipeline is configured.
    """

    def query_handlers(
        handler_call_details: _HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        return method_with_handler.handler(handler_call_details)

    method_name = method_with_handler.name() or _common.decode(
        rpc_event.call_details.method
    )
    handler_call_details = _HandlerCallDetails(
        method_name, rpc_event.invocation_metadata
    )
    if interceptor_pipeline is None:
        return state.context.run(query_handlers, handler_call_details)
    return state.context.run(
        interceptor_pipeline.execute, query_handlers, handler_call_details
    )
def _reject_rpc(
    rpc_event: cygrpc.BaseEvent,
    rpc_state: _RPCState,
    status: cygrpc.StatusCode,
    details: bytes,
):
    """Finish a call immediately with *status* and *details*, without a handler.

    A single batch sends initial metadata, receives the close, and sends the
    status. Its completion callback reports ``(rpc_state, ())``: the call is
    finished and there are no callbacks to run.
    """
    operations = (
        _get_initial_metadata_operation(rpc_state, None),
        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
        cygrpc.SendStatusFromServerOperation(
            None, status, details, _EMPTY_FLAGS
        ),
    )

    def on_rejected(ignored_event):
        return rpc_state, ()

    rpc_event.call.start_server_batch(operations, on_rejected)
def _handle_with_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Start watching for the call's close, then dispatch by streaming arity.

    The ReceiveCloseOnServer batch is started under ``state.condition``
    before the handler is scheduled. Dispatch picks one of the four
    _handle_* functions from ``(request_streaming, response_streaming)``.
    """
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state),
        )
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        dispatch = {
            (True, True): _handle_stream_stream,
            (True, False): _handle_stream_unary,
            (False, True): _handle_unary_stream,
            (False, False): _handle_unary_unary,
        }
        arity = (
            bool(method_handler.request_streaming),
            bool(method_handler.response_streaming),
        )
        return dispatch[arity](rpc_event, state, method_handler, thread_pool)
def _handle_call(
    rpc_event: cygrpc.BaseEvent,
    method_with_handler: _Method,
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
    thread_pool: futures.ThreadPoolExecutor,
    concurrency_exceeded: bool,
) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
    """Start servicing one incoming call event.

    ``method_with_handler`` supplies the handler. A _RegisteredMethod carries
    the handler registered under its name and ignores the call details. A
    _GenericMethod has no name and queries the generic handlers, using the
    method name decoded from ``rpc_event.call_details.method``. The lookup
    goes through ``interceptor_pipeline`` when one is given.

    Args:
        rpc_event: Request-call event from core.
        method_with_handler: Source of the handler for this call.
        interceptor_pipeline: Optional server interceptors applied to lookup.
        thread_pool: Default pool in which the handler runs.
        concurrency_exceeded: If true, a found handler is rejected with
            RESOURCE_EXHAUSTED instead of being run.

    Returns:
        ``(None, None)`` if the event failed or names no method.
        ``(rpc_state, None)`` if the call was rejected: the lookup raised,
        no handler was found, or concurrency was exceeded.
        ``(rpc_state, future)`` if the handler was submitted to a pool.
    """
    if not rpc_event.success:
        return None, None
    if not (rpc_event.call_details.method or method_with_handler.name()):
        return None, None

    rpc_state = _RPCState()
    try:
        method_handler = _find_method_handler(
            rpc_event,
            rpc_state,
            method_with_handler,
            interceptor_pipeline,
        )
    except Exception as exception:  # pylint: disable=broad-except
        details = "Exception servicing handler: {}".format(exception)
        _LOGGER.exception(details)
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unknown,
            b"Error in service handler!",
        )
        return rpc_state, None

    if method_handler is None:
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unimplemented,
            b"Method not found!",
        )
        return rpc_state, None
    if concurrency_exceeded:
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.resource_exhausted,
            b"Concurrent RPC limit exceeded!",
        )
        return rpc_state, None
    return (
        rpc_state,
        _handle_with_method_handler(
            rpc_event, rpc_state, method_handler, thread_pool
        ),
    )
@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of a server, stored in ``_ServerState.stage``."""

    # Initial stage, and the stage restored once shutdown has fully finished.
    STOPPED = "stopped"
    # Entered by _start(); new calls are requested and served.
    STARTED = "started"
    # Entered when Core shutdown has been requested; in-flight work drains.
    GRACE = "grace"
class _ServerState(object):
    """Mutable bookkeeping shared by a ``_Server`` and its serving thread.

    A fresh state starts in the ``STOPPED`` stage with no active RPCs, no
    registered-method handlers and nothing outstanding in ``due``.
    """

    lock: threading.RLock
    completion_queue: cygrpc.CompletionQueue
    server: cygrpc.Server
    generic_handlers: List[grpc.GenericRpcHandler]
    registered_method_handlers: Dict[str, grpc.RpcMethodHandler]
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
    thread_pool: futures.ThreadPoolExecutor
    stage: _ServerStage
    termination_event: threading.Event
    shutdown_events: List[threading.Event]
    maximum_concurrent_rpcs: Optional[int]
    active_rpc_count: int
    rpc_states: Set[_RPCState]
    due: Set[str]
    server_deallocated: bool

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        completion_queue: cygrpc.CompletionQueue,
        server: cygrpc.Server,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptor_pipeline: Optional[_interceptor._ServicePipeline],
        thread_pool: futures.ThreadPoolExecutor,
        maximum_concurrent_rpcs: Optional[int],
    ):
        # Synchronization for everything below (reentrant, so helpers that
        # take the lock may be called while it is already held).
        self.lock = threading.RLock()

        # Core objects and configuration handed in by the _Server.
        self.completion_queue = completion_queue
        self.server = server
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs

        # Handler tables; copied/created here so later additions do not
        # mutate the caller's sequence.
        self.generic_handlers = list(generic_handlers)
        self.registered_method_handlers = {}

        # Lifecycle tracking. The termination event is always the first
        # entry of shutdown_events so it is set along with every other one.
        self.stage = _ServerStage.STOPPED
        termination = threading.Event()
        self.termination_event = termination
        self.shutdown_events = [termination]
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False
def _add_generic_handlers(
    state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler]
) -> None:
    """Appends ``generic_handlers`` to ``state.generic_handlers`` in place.

    The append happens while holding ``state.lock``.
    """
    with state.lock:
        registered = state.generic_handlers
        # list += iterable extends the same list object in place.
        registered += generic_handlers
def _add_registered_method_handlers(
    state: _ServerState, method_handlers: Dict[str, grpc.RpcMethodHandler]
) -> None:
    """Merges ``method_handlers`` into ``state.registered_method_handlers``.

    Entries for an already-present method name overwrite the old handler. The
    merge happens while holding ``state.lock``.
    """
    with state.lock:
        table = state.registered_method_handlers
        for method_name, handler in method_handlers.items():
            table[method_name] = handler
def _add_insecure_port(state: _ServerState, address: bytes) -> int:
    """Binds ``address`` without credentials and returns Core's result.

    The returned value is whatever ``state.server.add_http2_port`` reports;
    the public caller validates it.
    """
    with state.lock:
        port = state.server.add_http2_port(address)
    return port
def _add_secure_port(
    state: _ServerState,
    address: bytes,
    server_credentials: grpc.ServerCredentials,
) -> int:
    """Binds ``address`` using the Core credentials wrapped by
    ``server_credentials`` and returns Core's ``add_http2_port`` result.
    """
    with state.lock:
        core_server = state.server
        bound = core_server.add_http2_port(
            address, server_credentials._credentials
        )
    return bound
def _request_call(state: _ServerState) -> None:
    """Asks Core for the next call to a non-registered method.

    The request is tagged with ``_REQUEST_CALL_TAG`` and that tag is recorded
    in ``state.due`` so the server is not torn down while it is outstanding.
    """
    tag = _REQUEST_CALL_TAG
    queue = state.completion_queue
    state.server.request_call(queue, queue, tag)
    state.due.add(tag)
def _request_registered_call(state: _ServerState, method: str) -> None:
    """Asks Core for the next call to the registered method ``method``.

    The method name itself is used as the completion tag and is recorded in
    ``state.due`` until the matching event arrives.
    """
    queue = state.completion_queue
    state.server.request_registered_call(queue, queue, method, method)
    state.due.add(method)
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state: _ServerState) -> bool:
    """Finishes shutdown if no RPC state or Core request is outstanding.

    When both ``state.rpc_states`` and ``state.due`` are empty, destroys the
    Core server, sets every shutdown event, moves the stage back to
    ``STOPPED`` and returns True. Otherwise leaves everything untouched and
    returns False. Called with ``state.lock`` held.
    """
    if state.rpc_states or state.due:
        return False
    state.server.destroy()
    for event in state.shutdown_events:
        event.set()
    state.stage = _ServerStage.STOPPED
    return True
def _on_call_completed(state: _ServerState) -> None:
    """Done-callback for an RPC future: releases one active-RPC slot."""
    with state.lock:
        state.active_rpc_count = state.active_rpc_count - 1
# pylint: disable=too-many-branches
def _process_event_and_continue(
    state: _ServerState, event: cygrpc.BaseEvent
) -> bool:
    """Handles one completion-queue event and reports whether to keep serving.

    Three kinds of events are handled:

    * ``_SHUTDOWN_TAG``: Core has finished the requested server shutdown.
    * ``_REQUEST_CALL_TAG`` or a registered method name: a new incoming call.
      It is dispatched through ``_handle_call``. While the server is still
      ``STARTED``, the consumed call request is re-issued so further calls of
      the same kind are accepted.
    * Anything else: the tag is a callable that processes the event and
      returns ``(rpc_state, callbacks)``. The callbacks are run (exceptions
      are logged), and a returned RPC state is forgotten.

    Args:
      state: The server state.
      event: An event returned from polling ``state.completion_queue``.

    Returns:
      False once ``_stop_serving`` reports that the server fully stopped,
      True otherwise.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif (
        event.tag is _REQUEST_CALL_TAG
        or event.tag in state.registered_method_handlers
    ):
        registered_method_name = None
        if event.tag in state.registered_method_handlers:
            # Calls to registered methods are tagged with the method name.
            registered_method_name = event.tag
            method_with_handler = _RegisteredMethod(
                registered_method_name,
                state.registered_method_handlers.get(
                    registered_method_name, None
                ),
            )
        else:
            method_with_handler = _GenericMethod(
                state.generic_handlers,
            )
        with state.lock:
            state.due.remove(event.tag)
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None
                and state.active_rpc_count >= state.maximum_concurrent_rpcs
            )
            rpc_state, rpc_future = _handle_call(
                event,
                method_with_handler,
                state.interceptor_pipeline,
                state.thread_pool,
                concurrency_exceeded,
            )
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state)
                )
            if state.stage is _ServerStage.STARTED:
                # Re-arm the request this event consumed.
                if registered_method_name in state.registered_method_handlers:
                    _request_registered_call(state, registered_method_name)
                else:
                    _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        rpc_state, callbacks = event.tag(event)
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Exception calling callback!")
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
def _serve(state: _ServerState) -> None:
    """Polling loop run by the server's daemon thread.

    Each poll is bounded by ``_DEALLOCATED_SERVER_CHECK_PERIOD_S`` so that
    ``state.server_deallocated`` is re-checked periodically. Non-timeout
    events go to ``_process_event_and_continue``, and the loop exits once
    that returns False.
    """
    while True:
        deadline = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(deadline)
        if state.server_deallocated:
            # __del__ cannot take the lock, so shutdown starts from here.
            _begin_shutdown_once(state)
        timed_out = (
            event.completion_type == cygrpc.CompletionType.queue_timeout
        )
        if not timed_out and not _process_event_and_continue(state, event):
            return
        # Release this event *before* polling again: if it references a
        # shutdown Call object, keeping it alive can induce spinlock.
        event = None
def _begin_shutdown_once(state: _ServerState) -> None:
    """Requests Core server shutdown, but only from the ``STARTED`` stage.

    On that transition the stage becomes ``GRACE`` and ``_SHUTDOWN_TAG`` is
    recorded in ``state.due``. In any other stage this is a no-op, so repeated
    calls are harmless.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)
def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Begins shutdown and returns an event that is set once it completes.

    Args:
      state: The server state.
      grace: ``None`` cancels all calls immediately and blocks until the
        server has stopped. Otherwise a background thread cancels all calls
        after ``grace`` seconds (or as soon as shutdown finishes), and this
        function returns without waiting.

    Returns:
      An Event that is set once the server has stopped. If the server is
      already ``STOPPED``, it is returned already set.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped

        _begin_shutdown_once(state)
        shutdown_event = threading.Event()
        state.shutdown_events.append(shutdown_event)

        if grace is not None:

            def cancel_all_calls_after_grace():
                shutdown_event.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            threading.Thread(target=cancel_all_calls_after_grace).start()
            return shutdown_event

        state.server.cancel_all_calls()

    # Wait outside the lock: the serving thread must take it to set the event.
    shutdown_event.wait()
    return shutdown_event
def _start(state: _ServerState) -> None:
    """Starts the Core server, arms call requests and spawns the serving thread.

    Raises:
      ValueError: if the server is not in the ``STOPPED`` stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError("Cannot start already-started server!")
        state.server.start()
        state.stage = _ServerStage.STARTED
        # One outstanding request per registered method...
        for method in state.registered_method_handlers:
            _request_registered_call(state, method)
        # ...plus one for calls to non-registered methods.
        _request_call(state)
        serving_thread = threading.Thread(
            target=_serve, args=(state,), daemon=True
        )
        serving_thread.start()
def _validate_generic_rpc_handlers(
    generic_rpc_handlers: Iterable[grpc.GenericRpcHandler],
) -> None:
    """Checks that every handler exposes a non-None ``service`` attribute.

    Raises:
      AttributeError: naming the first handler whose ``service`` attribute is
        missing or None.
    """
    for handler in generic_rpc_handlers:
        if getattr(handler, "service", None) is not None:
            continue
        raise AttributeError(
            '"{}" must conform to grpc.GenericRpcHandler type but does '
            'not have "service" method!'.format(handler)
        )
def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> Sequence[ChannelArgumentType]:
    """Returns ``base_options`` as a tuple, extended with the compression
    channel option and the (possibly empty) server call tracer factory
    option.
    """
    compression_options = _compression.create_channel_option(compression)
    tracer_options = _observability.create_server_call_tracer_factory_option(
        xds
    )
    augmented = tuple(base_options)
    augmented = augmented + compression_options
    augmented = augmented + tracer_options
    return augmented
class _Server(grpc.Server):
    """``grpc.Server`` implementation backed by a cygrpc (Core) server.

    All mutable server state lives in ``self._state`` and is manipulated
    through the module-level helper functions.
    """

    _state: _ServerState

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        thread_pool: futures.ThreadPoolExecutor,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptors: Sequence[grpc.ServerInterceptor],
        options: Sequence[ChannelArgumentType],
        maximum_concurrent_rpcs: Optional[int],
        compression: Optional[grpc.Compression],
        xds: bool,
    ):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression, xds), xds)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(
            completion_queue,
            server,
            generic_handlers,
            _interceptor.service_pipeline(interceptors),
            thread_pool,
            maximum_concurrent_rpcs,
        )
        self._cy_server = server

    def add_generic_rpc_handlers(
        self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
    ) -> None:
        """Validates and appends generic handlers to this server.

        Raises:
          AttributeError: if a handler has no ``service`` attribute.
        """
        # The handlers are iterated twice (validation, then registration).
        # Materialize them first. Otherwise validation exhausts a one-shot
        # iterable such as a generator and nothing gets registered.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_registered_method_handlers(
        self,
        service_name: str,
        method_handlers: Dict[str, grpc.RpcMethodHandler],
    ) -> None:
        """Registers per-method handlers of ``service_name`` with Core.

        Does nothing once the server is ``STARTED``.
        """
        # Hold the (reentrant) lock across the stage check *and* the
        # registration so a concurrent start(), which takes the same lock,
        # cannot slip in between them.
        with self._state.lock:
            # Can't register method once server started.
            if self._state.stage is _ServerStage.STARTED:
                return

            # TODO(xuanwn): We should validate method_handlers first.
            method_to_handlers = {
                _common.fully_qualified_method(
                    service_name, method
                ): method_handler
                for method, method_handler in method_handlers.items()
            }
            for fully_qualified_method in method_to_handlers:
                self._cy_server.register_method(fully_qualified_method)
            _add_registered_method_handlers(self._state, method_to_handlers)

    def add_insecure_port(self, address: str) -> int:
        """Binds ``address`` without credentials; returns the validated port."""
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address))
        )

    def add_secure_port(
        self, address: str, server_credentials: grpc.ServerCredentials
    ) -> int:
        """Binds ``address`` with ``server_credentials``; returns the port."""
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(
                self._state, _common.encode(address), server_credentials
            ),
        )

    def start(self) -> None:
        """Starts serving; raises ValueError if already started."""
        _start(self._state)

    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
        """Blocks until the server terminates or ``timeout`` elapses."""
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(
            self._state.termination_event.wait,
            self._state.termination_event.is_set,
            timeout=timeout,
        )

    def stop(self, grace: Optional[float]) -> threading.Event:
        """Stops the server; see ``grpc.Server.stop`` for ``grace`` semantics."""
        return _stop(self._state, grace)

    def __del__(self):
        if hasattr(self, "_state"):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True
def create_server(
    thread_pool: futures.ThreadPoolExecutor,
    generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
    interceptors: Sequence[grpc.ServerInterceptor],
    options: Sequence[ChannelArgumentType],
    maximum_concurrent_rpcs: Optional[int],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> _Server:
    """Validates the generic handlers and constructs a ``_Server``.

    Raises:
      AttributeError: if any generic handler has no ``service`` attribute.
    """
    # Materialize once: validation iterates the handlers and the server state
    # copies them again, so a one-shot iterable would otherwise arrive empty.
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(
        thread_pool,
        generic_rpc_handlers,
        interceptors,
        options,
        maximum_concurrent_rpcs,
        compression,
        xds,
    )