1from logging import getLogger
2from typing import Any, Union
3
4from ..exceptions import ConnectionError, InvalidResponse, ResponseError
5from ..typing import EncodableT
6from .base import (
7 AsyncPushNotificationsParser,
8 PushNotificationsParser,
9 _AsyncRESPBase,
10 _RESPBase,
11)
12from .socket import SERVER_CLOSED_CONNECTION_ERROR
13
14
class _RESP3Parser(_RESPBase, PushNotificationsParser):
    """RESP3 protocol implementation (blocking socket buffer)."""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Pub/sub push frames default to handle_pubsub_push_response below;
        # the other push handlers start unset.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.invalidation_push_handler_func = None

    def handle_pubsub_push_response(self, response):
        """Log a pub/sub push frame at DEBUG level and return it unchanged."""
        logger = getLogger("push_response")
        # Pass the payload as a lazy %-style argument so it is only converted
        # to a string when a DEBUG record is actually emitted.
        logger.debug("Push response: %s", response)
        return response

    def read_response(self, disable_decoding=False, push_request=False):
        """Read and parse one complete reply from the socket buffer.

        If parsing raises anything (``BaseException``, so interrupts too), the
        buffer is rewound to the position it had when this call started and
        the exception propagates; on success the consumed data is purged.

        :param disable_decoding: if True, simple/bulk/verbatim strings are
            returned as raw ``bytes`` instead of being decoded by the encoder.
        :param push_request: if True, a push frame is passed through
            ``handle_push_response`` and that result is returned, instead of
            continuing on to the next frame.
        """
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ):
        """Parse a single RESP3 frame from ``self._buffer``.

        Aggregates (arrays, sets, maps, pushes) are built by calling this
        method once per element. Sets are returned as lists and maps as dicts.
        Server errors that ``parse_error`` does not turn into a
        ``ConnectionError`` are returned as exception instances, not raised.

        :raises ConnectionError: when the buffer yields no data, or when the
            server error maps to a ``ConnectionError``.
        :raises InvalidResponse: on an unrecognised type byte.
        """
        # Push frames ('>') can arrive ahead of the reply being waited for.
        # Consume them in a loop rather than by recursing once per push, so a
        # long burst of pushes cannot exhaust the interpreter recursion limit.
        while True:
            raw = self._buffer.readline()
            if not raw:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            byte, response = raw[:1], raw[1:]
            if byte != b">":
                break

            # push response
            push = [
                self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
                for _ in range(int(response))
            ]
            push = self.handle_push_response(push)
            # if this is a push request return the push response
            if push_request:
                return push
            # otherwise the push has been dispatched; parse the next frame

        # server returned an error
        if byte in (b"-", b"!"):
            # '!' is a blob error: the header line carries the payload length
            if byte == b"!":
                response = self._buffer.read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # verbatim string response: drop the 4-byte "xxx:" format prefix
        elif byte == b"=":
            response = self._buffer.read(int(response))[4:]
        # array response
        elif byte == b"*":
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we return sets as list, all the time, for predictability
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse stream.
            # Evaluation order of key:val expression in dict comprehension only
            # became defined to be left-right in version 3.8
            resp_dict = {}
            for _ in range(int(response)):
                key = self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)

        return response
138
139
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
    """RESP3 protocol implementation (asyncio stream)."""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Pub/sub push frames default to handle_pubsub_push_response below;
        # the invalidation push handler starts unset.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None

    async def handle_pubsub_push_response(self, response):
        """Log a pub/sub push frame at DEBUG level and return it unchanged."""
        logger = getLogger("push_response")
        # Pass the payload as a lazy %-style argument so it is only converted
        # to a string when a DEBUG record is actually emitted.
        logger.debug("Push response: %s", response)
        return response

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ):
        """Read and parse one complete reply from the stream.

        Any data held in ``self._chunks`` is first appended to the parse
        buffer, and parsing restarts from position 0. The parse buffer is
        cleared only after a reply has been parsed successfully.

        :param disable_decoding: if True, simple/bulk/verbatim strings are
            returned as raw ``bytes`` instead of being decoded by the encoder.
        :param push_request: if True, a push frame is passed through
            ``handle_push_response`` and that result is returned, instead of
            continuing on to the next frame.
        """
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        self._pos = 0
        response = await self._read_response(
            disable_decoding=disable_decoding, push_request=push_request
        )
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    async def _read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        """Parse a single RESP3 frame from the stream.

        Aggregates (arrays, sets, maps, pushes) are built by awaiting this
        method once per element. Sets are returned as lists and maps as dicts.
        Server errors that ``parse_error`` does not turn into a
        ``ConnectionError`` are returned as exception instances, not raised.

        :raises ConnectionError: when there is no stream or encoder, or when
            the server error maps to a ``ConnectionError``.
        :raises InvalidResponse: on an unrecognised type byte.
        """
        response: Any
        # Push frames ('>') can arrive ahead of the reply being waited for.
        # Consume them in a loop rather than by recursing once per push, so a
        # long burst of pushes cannot exhaust the interpreter recursion limit.
        while True:
            if not self._stream or not self.encoder:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            raw = await self._readline()
            byte, response = raw[:1], raw[1:]
            if byte != b">":
                break

            # push response
            push = [
                (
                    await self._read_response(
                        disable_decoding=disable_decoding,
                        push_request=push_request,
                    )
                )
                for _ in range(int(response))
            ]
            push = await self.handle_push_response(push)
            # if this is a push request return the push response
            if push_request:
                return push
            # otherwise the push has been dispatched; parse the next frame

        # server returned an error
        if byte in (b"-", b"!"):
            # '!' is a blob error: the header line carries the payload length
            if byte == b"!":
                response = await self._read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                # the error frame has been read in full; drop the parse
                # buffer before raising
                self._clear()
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = await self._read(int(response))
        # verbatim string response: drop the 4-byte "xxx:" format prefix
        elif byte == b"=":
            response = (await self._read(int(response)))[4:]
        # array response
        elif byte == b"*":
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we always convert to a list, to have predictable return types
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse stream.
            # Evaluation order of key:val expression in dict comprehension only
            # became defined to be left-right in version 3.8
            resp_dict = {}
            for _ in range(int(response)):
                key = await self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = await self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)
        return response