1import asyncio
2from abc import ABC
3from typing import (
4 TYPE_CHECKING,
5 AsyncIterable,
6 AsyncIterator,
7 Collection,
8 Iterable,
9 Mapping,
10 Optional,
11 Tuple,
12 Type,
13 Union,
14)
15
16import grpclib.const
17
18
19if TYPE_CHECKING:
20 from grpclib.client import Channel
21 from grpclib.metadata import Deadline
22
23 from .._types import (
24 ST,
25 IProtoMessage,
26 Message,
27 T,
28 )
29
30
# A single gRPC metadata value: text or raw bytes.
Value = Union[str, bytes]

# Request metadata given either as a mapping or as a collection of (key, value)
# pairs (the latter allows repeated keys).
MetadataLike = Union[
    Mapping[str, Value],
    Collection[Tuple[str, Value]],
]

# Outgoing request messages for client-streaming RPCs: a plain iterable or an
# async iterable of protobuf messages.
MessageSource = Union[
    Iterable["IProtoMessage"],
    AsyncIterable["IProtoMessage"],
]
34
35
class ServiceStub(ABC):
    """
    Base class for async gRPC clients.

    Subclasses implement their RPC methods on top of the ``_unary_unary``,
    ``_unary_stream``, ``_stream_unary`` and ``_stream_stream`` helpers, which
    open a stream on ``self.channel`` for the given route and message types.
    """

    def __init__(
        self,
        channel: "Channel",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> None:
        """
        :param channel: grpclib channel on which every call opens its stream.
        :param timeout: default timeout, used by calls that do not pass their own.
        :param deadline: default deadline, used by calls that do not pass their own.
        :param metadata: default request metadata, used by calls that do not pass
            their own.
        """
        self.channel = channel
        self.timeout = timeout
        self.deadline = deadline
        self.metadata = metadata

    def __resolve_request_kwargs(
        self,
        timeout: Optional[float],
        deadline: Optional["Deadline"],
        metadata: Optional[MetadataLike],
    ) -> dict:
        """
        Build the ``timeout``/``deadline``/``metadata`` keyword arguments for
        ``Channel.request``, preferring per-call values over the stub defaults.

        A per-call value of ``None`` means "not given", so the stub default is
        used. Per-call ``metadata`` replaces the default metadata entirely; the
        two are not merged.
        """
        return {
            "timeout": self.timeout if timeout is None else timeout,
            "deadline": self.deadline if deadline is None else deadline,
            "metadata": self.metadata if metadata is None else metadata,
        }

    async def _unary_unary(
        self,
        route: str,
        request: "IProtoMessage",
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> "T":
        """
        Make a unary request and return the response.

        The request type passed to the channel is ``type(request)``.
        """
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.UNARY_UNARY,
            type(request),
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_message(request, end=True)
            response = await stream.recv_message()
            # A unary RPC must produce exactly one response message.
            assert response is not None
            return response

    async def _unary_stream(
        self,
        route: str,
        request: "IProtoMessage",
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> AsyncIterator["T"]:
        """
        Make a unary request and yield each message of the response stream.

        The stream stays open while the caller iterates; it is closed when the
        iteration ends or the generator is closed.
        """
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.UNARY_STREAM,
            type(request),
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_message(request, end=True)
            async for message in stream:
                yield message

    async def _stream_unary(
        self,
        route: str,
        request_iterator: MessageSource,
        request_type: Type["IProtoMessage"],
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> "T":
        """
        Send every message from ``request_iterator`` (sync or async iterable),
        end the request stream, and return the single response.
        """
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.STREAM_UNARY,
            request_type,
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            # All requests are sent (and the stream ended) before reading.
            await self._send_messages(stream, request_iterator)
            response = await stream.recv_message()
            # A stream-unary RPC must produce exactly one response message.
            assert response is not None
            return response

    async def _stream_stream(
        self,
        route: str,
        request_iterator: MessageSource,
        request_type: Type["IProtoMessage"],
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> AsyncIterator["T"]:
        """
        Make a stream request and return an AsyncIterator to iterate over response
        messages.

        Request messages are sent from a background task so that responses can
        be consumed while requests are still being produced. The background task
        is cancelled whenever this generator stops, for whatever reason, while
        it is still running.
        """
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.STREAM_STREAM,
            request_type,
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            # Start the request explicitly before spawning the sender task,
            # presumably so the sender and the receive loop below do not race
            # to start it.
            await stream.send_request()
            sending_task = asyncio.ensure_future(
                self._send_messages(stream, request_iterator)
            )
            # NOTE(review): if request_iterator raises, the sender task fails
            # without ending the request stream and its exception is not
            # surfaced here; the loop below may then wait until the server
            # closes the call or the deadline expires. Consider propagating it.
            try:
                async for response in stream:
                    yield response
            finally:
                # Stop the sender before leaving the stream context. This
                # covers errors, cancellation, the caller stopping early
                # (GeneratorExit from aclose()), and the server ending the
                # response stream before all requests were sent. Otherwise the
                # task would keep running after the stream is released.
                if not sending_task.done():
                    sending_task.cancel()

    @staticmethod
    async def _send_messages(stream, messages: MessageSource) -> None:
        """
        Send every message from ``messages`` on ``stream``, then end the
        request side of the stream.

        :param stream: the open grpclib stream to write to.
        :param messages: a regular iterable or an async iterable of messages.
        """
        if isinstance(messages, AsyncIterable):
            async for message in messages:
                await stream.send_message(message)
        else:
            for message in messages:
                await stream.send_message(message)
        await stream.end()