1from logging import getLogger
2from typing import Any, Union
3
4from ..exceptions import ConnectionError, InvalidResponse, ResponseError
5from ..typing import EncodableT
6from .base import (
7 AsyncPushNotificationsParser,
8 PushNotificationsParser,
9 _AsyncRESPBase,
10 _RESPBase,
11)
12from .socket import SERVER_CLOSED_CONNECTION_ERROR
13
14
class _RESP3Parser(_RESPBase, PushNotificationsParser):
    """RESP3 protocol implementation"""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Push handlers: pub/sub pushes default to the debug-logging handler
        # below; no invalidation handler is installed until one is assigned.
        # NOTE(review): presumably consulted by handle_push_response in the
        # base class (not visible here) — confirm there.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None

    def handle_pubsub_push_response(self, response):
        """Log ``response`` at DEBUG level on the "push_response" logger.

        Returns:
            ``response`` unchanged.
        """
        logger = getLogger("push_response")
        # Lazy %-style argument: str(response) is only computed when a DEBUG
        # record is actually emitted, not on every push frame.
        logger.debug("Push response: %s", response)
        return response

    def read_response(self, disable_decoding=False, push_request=False):
        """Read and parse one complete reply from the socket buffer.

        If parsing raises anything (any ``BaseException``), the buffer is
        rewound to the position recorded before parsing started and the
        exception is re-raised. On success ``purge()`` is called on the
        buffer and the parsed value is returned.

        Args:
            disable_decoding: return string payloads as raw ``bytes`` instead
                of decoding them with ``self.encoder``.
            push_request: return a push frame to the caller instead of only
                passing it to ``handle_push_response`` and reading on.
        """
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False, push_request=False):
        """Parse a single RESP3 frame, recursing into aggregate types.

        Args:
            disable_decoding: when True, ``bytes`` payloads are returned as-is
                instead of being passed through ``self.encoder.decode``.
            push_request: when True, a push frame (``>``) is returned; when
                False it is passed to ``handle_push_response`` and the next
                frame is parsed and returned instead.

        Returns:
            The parsed value (``None`` for null frames), or a ``ResponseError``
            instance for server errors that are not connection errors.

        Raises:
            ConnectionError: the buffer returned no data, or ``parse_error``
                produced a ConnectionError.
            InvalidResponse: the frame starts with an unknown type byte.
        """
        # An unrequested push frame is handled and then the next frame is
        # read. This is a loop rather than a recursive call so that a long
        # run of consecutive pushes cannot exhaust the recursion limit.
        while True:
            raw = self._buffer.readline()
            if not raw:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            byte, response = raw[:1], raw[1:]

            # server returned an error: simple ("-") or blob ("!", whose
            # header line carries the payload length)
            if byte in (b"-", b"!"):
                if byte == b"!":
                    response = self._buffer.read(int(response))
                response = response.decode("utf-8", errors="replace")
                error = self.parse_error(response)
                # if the error is a ConnectionError, raise immediately so the
                # user is notified
                if isinstance(error, ConnectionError):
                    raise error
                # otherwise, we're dealing with a ResponseError that might
                # belong inside a pipeline response. the connection's
                # read_response() and/or the pipeline's execute() will raise
                # this error if necessary, so just return the exception
                # instance here.
                return error
            # single value
            elif byte == b"+":
                pass
            # null value
            elif byte == b"_":
                return None
            # int and big int values
            elif byte in (b":", b"("):
                return int(response)
            # double value
            elif byte == b",":
                return float(response)
            # bool value
            elif byte == b"#":
                return response == b"t"
            # bulk response
            elif byte == b"$":
                length = int(response)
                # RESP2-style null bulk string ("$-1"): there is no payload,
                # so don't ask the buffer for a negative number of bytes.
                if length == -1:
                    return None
                response = self._buffer.read(length)
            # verbatim string response: drop the leading 4 bytes (per the
            # RESP3 spec, a 3-character format tag followed by ":")
            elif byte == b"=":
                response = self._buffer.read(int(response))[4:]
            # array response
            elif byte == b"*":
                length = int(response)
                # RESP2-style null array ("*-1")
                if length == -1:
                    return None
                response = [
                    self._read_response(disable_decoding=disable_decoding)
                    for _ in range(length)
                ]
            # set response
            elif byte == b"~":
                # redis can return unhashable types (like dict) in a set,
                # so we return sets as list, all the time, for predictability
                response = [
                    self._read_response(disable_decoding=disable_decoding)
                    for _ in range(int(response))
                ]
            # map response
            elif byte == b"%":
                # We cannot use a dict-comprehension to parse stream.
                # Evaluation order of key:val expression in dict comprehension
                # only became defined to be left-right in version 3.8
                resp_dict = {}
                for _ in range(int(response)):
                    key = self._read_response(disable_decoding=disable_decoding)
                    resp_dict[key] = self._read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
                response = resp_dict
            # push response
            elif byte == b">":
                response = [
                    self._read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
                    for _ in range(int(response))
                ]
                response = self.handle_push_response(response)
                if push_request:
                    return response
                # The caller is waiting for a regular reply, not this push:
                # go round again and parse the next frame.
                continue
            else:
                raise InvalidResponse(f"Protocol Error: {raw!r}")

            if isinstance(response, bytes) and disable_decoding is False:
                response = self.encoder.decode(response)
            return response
132
133
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
    """Asynchronous (``await``-based) RESP3 protocol implementation."""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Push handlers: pub/sub pushes default to the debug-logging handler
        # below; no invalidation handler is installed until one is assigned.
        # NOTE(review): presumably consulted by handle_push_response in the
        # base class (not visible here) — confirm there.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None

    async def handle_pubsub_push_response(self, response):
        """Log ``response`` at DEBUG level on the "push_response" logger.

        Returns:
            ``response`` unchanged.
        """
        logger = getLogger("push_response")
        # Lazy %-style argument: str(response) is only computed when a DEBUG
        # record is actually emitted, not on every push frame.
        logger.debug("Push response: %s", response)
        return response

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ):
        """Parse one complete reply and then clear the parse buffer.

        Any chunks accumulated in ``self._chunks`` are first appended to
        ``self._buffer`` and parsing restarts at position 0.

        Args:
            disable_decoding: return string payloads as raw ``bytes`` instead
                of decoding them with ``self.encoder``.
            push_request: return a push frame to the caller instead of only
                passing it to ``handle_push_response`` and reading on.
        """
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        # always (re)start parsing from the beginning of the buffer
        self._pos = 0
        response = await self._read_response(
            disable_decoding=disable_decoding, push_request=push_request
        )
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    async def _read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        """Parse a single RESP3 frame, recursing into aggregate types.

        Args:
            disable_decoding: when True, ``bytes`` payloads are returned as-is
                instead of being passed through ``self.encoder.decode``.
            push_request: when True, a push frame (``>``) is returned; when
                False it is passed to ``handle_push_response`` and the next
                frame is parsed and returned instead.

        Returns:
            The parsed value (``None`` for null frames), or a ``ResponseError``
            instance for server errors that are not connection errors.

        Raises:
            ConnectionError: there is no stream or encoder, or ``parse_error``
                produced a ConnectionError.
            InvalidResponse: the frame starts with an unknown type byte.
        """
        response: Any
        # An unrequested push frame is handled and then the next frame is
        # read. This is a loop rather than a recursive call so that a long
        # run of consecutive pushes cannot exhaust the recursion limit.
        while True:
            if not self._stream or not self.encoder:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            raw = await self._readline()
            byte, response = raw[:1], raw[1:]

            # server returned an error: simple ("-") or blob ("!", whose
            # header line carries the payload length)
            if byte in (b"-", b"!"):
                if byte == b"!":
                    response = await self._read(int(response))
                response = response.decode("utf-8", errors="replace")
                error = self.parse_error(response)
                # if the error is a ConnectionError, raise immediately so the
                # user is notified
                if isinstance(error, ConnectionError):
                    # reset the parse buffer before surfacing the error
                    self._clear()
                    raise error
                # otherwise, we're dealing with a ResponseError that might
                # belong inside a pipeline response. the connection's
                # read_response() and/or the pipeline's execute() will raise
                # this error if necessary, so just return the exception
                # instance here.
                return error
            # single value
            elif byte == b"+":
                pass
            # null value
            elif byte == b"_":
                return None
            # int and big int values
            elif byte in (b":", b"("):
                return int(response)
            # double value
            elif byte == b",":
                return float(response)
            # bool value
            elif byte == b"#":
                return response == b"t"
            # bulk response
            elif byte == b"$":
                length = int(response)
                # RESP2-style null bulk string ("$-1"): there is no payload,
                # so don't ask for a negative number of bytes.
                if length == -1:
                    return None
                response = await self._read(length)
            # verbatim string response: drop the leading 4 bytes (per the
            # RESP3 spec, a 3-character format tag followed by ":")
            elif byte == b"=":
                response = (await self._read(int(response)))[4:]
            # array response
            elif byte == b"*":
                length = int(response)
                # RESP2-style null array ("*-1")
                if length == -1:
                    return None
                response = [
                    (await self._read_response(disable_decoding=disable_decoding))
                    for _ in range(length)
                ]
            # set response
            elif byte == b"~":
                # redis can return unhashable types (like dict) in a set,
                # so we always convert to a list, to have predictable return types
                response = [
                    (await self._read_response(disable_decoding=disable_decoding))
                    for _ in range(int(response))
                ]
            # map response
            elif byte == b"%":
                # We cannot use a dict-comprehension to parse stream.
                # Evaluation order of key:val expression in dict comprehension
                # only became defined to be left-right in version 3.8
                resp_dict = {}
                for _ in range(int(response)):
                    key = await self._read_response(disable_decoding=disable_decoding)
                    resp_dict[key] = await self._read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
                response = resp_dict
            # push response
            elif byte == b">":
                response = [
                    (
                        await self._read_response(
                            disable_decoding=disable_decoding,
                            push_request=push_request,
                        )
                    )
                    for _ in range(int(response))
                ]
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                # The caller is waiting for a regular reply, not this push:
                # go round again and parse the next frame.
                continue
            else:
                raise InvalidResponse(f"Protocol Error: {raw!r}")

            if isinstance(response, bytes) and disable_decoding is False:
                response = self.encoder.decode(response)
            return response