1import asyncio
2from abc import ABC
3from typing import (
4 TYPE_CHECKING,
5 AsyncIterable,
6 AsyncIterator,
7 Collection,
8 Iterable,
9 Mapping,
10 Optional,
11 Tuple,
12 Type,
13 Union,
14)
15
16import grpclib.const
17
18
19if TYPE_CHECKING:
20 from grpclib.client import Channel
21 from grpclib.metadata import Deadline
22
23 from .._types import (
24 ST,
25 IProtoMessage,
26 Message,
27 T,
28 )
29
30
# A single metadata value: either text or raw bytes.
Value = Union[str, bytes]
# Request metadata, given as a mapping or as a collection of (key, value) pairs.
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
# Outgoing request messages for streaming calls: a sync or an async iterable.
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
34
35
class ServiceStub(ABC):
    """
    Base class for async gRPC clients.

    Holds the channel plus default ``timeout``, ``deadline`` and ``metadata``
    values. Each request helper accepts the same three keyword arguments; a
    per-call value that is not ``None`` takes precedence over the stub default.
    """

    def __init__(
        self,
        channel: "Channel",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> None:
        """Store the channel and the stub-wide request defaults."""
        self.channel = channel
        self.timeout = timeout
        self.deadline = deadline
        self.metadata = metadata

    def __resolve_request_kwargs(
        self,
        timeout: Optional[float],
        deadline: Optional["Deadline"],
        metadata: Optional[MetadataLike],
    ):
        """
        Build the keyword arguments passed to ``self.channel.request``.

        Each per-call argument that is ``None`` falls back to the value stored
        on the stub; any other value is used as given.
        """
        return {
            "timeout": self.timeout if timeout is None else timeout,
            "deadline": self.deadline if deadline is None else deadline,
            "metadata": self.metadata if metadata is None else metadata,
        }

    async def _unary_unary(
        self,
        route: str,
        request: "IProtoMessage",
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> "T":
        """Make a unary request and return the response."""
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.UNARY_UNARY,
            type(request),
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_message(request, end=True)
            response = await stream.recv_message()
        # Checked only after the request context has exited, so an error raised
        # while leaving the context takes precedence over this assertion.
        assert response is not None
        return response

    async def _unary_stream(
        self,
        route: str,
        request: "IProtoMessage",
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> AsyncIterator["T"]:
        """Make a unary request and return the stream response iterator."""
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.UNARY_STREAM,
            type(request),
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_message(request, end=True)
            async for message in stream:
                yield message

    async def _stream_unary(
        self,
        route: str,
        request_iterator: MessageSource,
        request_type: Type["IProtoMessage"],
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> "T":
        """Make a stream request and return the response."""
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.STREAM_UNARY,
            request_type,
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_request()
            # All request messages are sent before the single response is read.
            await self._send_messages(stream, request_iterator)
            response = await stream.recv_message()
        # Checked only after the request context has exited, so an error raised
        # while leaving the context takes precedence over this assertion.
        assert response is not None
        return response

    async def _stream_stream(
        self,
        route: str,
        request_iterator: MessageSource,
        request_type: Type["IProtoMessage"],
        response_type: Type["T"],
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional[MetadataLike] = None,
    ) -> AsyncIterator["T"]:
        """
        Make a stream request and return an AsyncIterator to iterate over response
        messages.

        Request messages are sent from a background task while responses are
        yielded, so both directions make progress concurrently.
        """
        async with self.channel.request(
            route,
            grpclib.const.Cardinality.STREAM_STREAM,
            request_type,
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_request()
            sending_task = asyncio.ensure_future(
                self._send_messages(stream, request_iterator)
            )
            try:
                async for response in stream:
                    yield response
            finally:
                # Stop the sender on every exit path: an error, the consumer
                # closing this generator early, or the responses finishing
                # before all requests were sent. Otherwise the task would keep
                # writing to a stream whose context has already been left.
                # cancel() does nothing if the task has already finished.
                sending_task.cancel()

    @staticmethod
    async def _send_messages(stream, messages: MessageSource):
        """
        Send every message from ``messages`` on ``stream``, then end the stream.

        ``messages`` may be an async iterable or a plain iterable; the matching
        loop form is chosen with an ``isinstance`` check.
        """
        if isinstance(messages, AsyncIterable):
            async for message in messages:
                await stream.send_message(message)
        else:
            for message in messages:
                await stream.send_message(message)
        # Called once the message source is exhausted.
        await stream.end()