1import asyncio
2import socket
3import sys
4from logging import getLogger
5from typing import Callable, List, Optional, TypedDict, Union
6
# ``asyncio.timeout`` exists from Python 3.11; older interpreters use the
# ``async_timeout`` backport. Compare the whole version tuple: checking the
# major and minor fields separately would misfire once the minor number
# resets (e.g. on a 4.0 release).
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
11
12from ..exceptions import ConnectionError, InvalidResponse, RedisError
13from ..typing import EncodableT
14from ..utils import HIREDIS_AVAILABLE
15from .base import (
16 AsyncBaseParser,
17 AsyncPushNotificationsParser,
18 BaseParser,
19 PushNotificationsParser,
20)
21from .socket import (
22 NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
23 NONBLOCKING_EXCEPTIONS,
24 SENTINEL,
25 SERVER_CLOSED_CONNECTION_ERROR,
26)
27
# Sentinel handed to hiredis as ``notEnoughData``: ``Reader.gets()`` yields
# it whenever the bytes fed so far do not yet form a complete reply. A fresh
# ``object()`` is used because ``False`` and ``None`` are themselves values a
# RESP payload can legitimately decode to, so they cannot mean "incomplete".
NOT_ENOUGH_DATA: object = object()
32
33
34class _HiredisReaderArgs(TypedDict, total=False):
35 protocolError: Callable[[str], Exception]
36 replyError: Callable[[str], Exception]
37 encoding: Optional[str]
38 errors: Optional[str]
39
40
class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for connections using Hiredis.

    Wraps a ``hiredis.Reader`` fed from a blocking socket. A reply parsed
    ahead of time by :meth:`can_read` is cached in ``_next_response`` and
    handed out by the next :meth:`read_response` call.
    """

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # Reusable receive buffer, filled in place by ``recv_into``.
        self._buffer = bytearray(socket_read_size)
        # Start in the disconnected state so that reads attempted before
        # on_connect() raise ConnectionError instead of AttributeError.
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.oss_cluster_maint_push_handler_func = None
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    def __del__(self):
        # Best-effort cleanup; a destructor must never raise.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the notification and return it."""
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection, **kwargs):
        """Bind to ``connection``'s socket and create a fresh hiredis reader.

        :param connection: object providing ``_sock``, ``socket_timeout`` and
            an ``encoder`` with ``encoding``, ``encoding_errors`` and
            ``decode_responses``.
        """
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        reader_kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            reader_kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**reader_kwargs)
        self._next_response = NOT_ENOUGH_DATA

        # hiredis < 3.2 has no PushNotification type; in that case push
        # replies are never routed to handle_push_response().
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Drop the socket, the reader and any reply cached by can_read()."""
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def can_read(self, timeout):
        """Return whether a reply is (probably) available.

        First tries to parse a reply from already-buffered data and caches it
        for the next read_response(). Failing that, performs one socket read
        bounded by ``timeout`` and returns whether any data arrived, which
        does not guarantee that a complete reply is buffered.

        :raises ConnectionError: if the parser is disconnected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is NOT_ENOUGH_DATA:
            self._next_response = self._reader.gets()
            if self._next_response is NOT_ENOUGH_DATA:
                return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Read one chunk from the socket and feed it to the hiredis reader.

        :param timeout: when given, temporarily overrides the socket timeout
            for this read; the connection's ``socket_timeout`` is restored
            afterwards.
        :param raise_on_timeout: when False, a timeout or a non-blocking
            "no data yet" error returns False instead of raising.
        :returns: True if data was read and fed to the reader; False if no
            data was available and ``raise_on_timeout`` is False.
        :raises ConnectionError: if the peer closed the connection or the
            read failed.
        :raises TimeoutError: the builtin one (this module does not import
            redis' TimeoutError), on timeout when ``raise_on_timeout`` is set.
        """
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = sock.recv_into(self._buffer)
            if bufflen == 0:
                # recv returning 0 bytes means the peer closed the socket.
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # In non-blocking mode a "would block" error just means there is
            # no data yet; report that as False when timeouts are tolerated,
            # otherwise surface the failure as a ConnectionError.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(
                f"Error while reading from socket: {ex.args}"
            ) from ex
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(
        self,
        disable_decoding=False,
        push_request=False,
        timeout: Union[float, object] = SENTINEL,
    ):
        """Return the next reply from the server.

        Push notifications (hiredis >= 3.2) are passed to
        ``handle_push_response``. With ``push_request`` the handled
        notification is returned; otherwise it is skipped and reading
        continues until a regular reply arrives.

        :param disable_decoding: parse newly read replies without decoding.
            A reply already cached by can_read() was parsed with a plain
            ``gets()`` call, so this flag does not apply to it.
        :param push_request: return the first push notification seen.
        :param timeout: socket timeout override applied to every socket read
            made by this call, including reads made after skipping a push.
        :raises ConnectionError: if disconnected, or if the reply is (or its
            first element is) a ConnectionError.
        """
        while True:
            # Checked per reply, in case a push handler disconnected us.
            if not self._reader:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            response = self._pop_or_read_reply(disable_decoding, timeout)

            # Cached and freshly parsed replies get the same error checks,
            # so a ConnectionError reply is always raised, never returned.
            if isinstance(response, ConnectionError):
                raise response
            if self._hiredis_PushNotificationType is not None and isinstance(
                response, self._hiredis_PushNotificationType
            ):
                response = self.handle_push_response(response)
                if push_request:
                    return response
                # The caller wants a regular reply: skip this notification.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response

    def _pop_or_read_reply(self, disable_decoding, timeout):
        """Return the reply cached by can_read(), or parse one from the socket."""
        if self._next_response is not NOT_ENOUGH_DATA:
            response = self._next_response
            self._next_response = NOT_ENOUGH_DATA
            return response
        response = self._reader_gets(disable_decoding)
        while response is NOT_ENOUGH_DATA:
            self.read_from_socket(timeout=timeout)
            response = self._reader_gets(disable_decoding)
        return response

    def _reader_gets(self, disable_decoding):
        """Pop one parsed reply (or NOT_ENOUGH_DATA) from the hiredis reader."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()
198
199
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis.

    Feeds a ``hiredis.Reader`` from the connection's asyncio stream reader.
    """

    # NOTE(review): other attributes are assigned on instances as well, so
    # this presumably relies on a base class providing __dict__ — confirm.
    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        # Not connected until on_connect(); reads attempted before then raise
        # ConnectionError / OSError instead of AttributeError.
        self._connected = False
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the notification and return it."""
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection):
        """Bind to ``connection``'s stream reader and create a hiredis reader.

        :param connection: object providing ``_reader`` (the stream) and an
            ``encoder`` with ``encoding``, ``encoding_errors`` and
            ``decode_responses``.
        """
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification type; in that case push
        # replies are never routed to handle_push_response().
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Mark the parser disconnected so further read_response() calls fail."""
        self._connected = False

    async def can_read_destructive(self):
        """Return whether data is available, consuming a buffered reply.

        If a complete reply is already buffered it is parsed and discarded
        and True is returned. Otherwise one stream read is attempted under a
        zero timeout: True if data was read, False if that read timed out.

        :raises OSError: if the parser is disconnected.
        """
        if not self._connected:
            raise OSError("Buffer is closed.")
        if self._reader.gets() is not NOT_ENOUGH_DATA:
            return True
        try:
            async with async_timeout(0):
                return await self.read_from_socket()
        except asyncio.TimeoutError:
            return False

    async def read_from_socket(self):
        """Read one chunk from the stream and feed it to the hiredis reader.

        :returns: True once data has been fed to the reader.
        :raises ConnectionError: if the stream returned no bytes (EOF) or
            something other than ``bytes``.
        """
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        return True

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        """Return the next reply from the server.

        Push notifications (hiredis >= 3.2) are passed to
        ``handle_push_response``. With ``push_request`` the handled
        notification is returned; otherwise it is skipped and reading
        continues until a regular reply arrives.

        :raises ConnectionError: if on_disconnect() was called, or if the
            reply is (or its first element is) a ConnectionError.
        """
        while True:
            # Once on_disconnect() has been called no new reply is read, even
            # if data is still buffered; a read already in progress (inside
            # _read_next_reply) is allowed to finish.
            if not self._connected:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

            response = await self._read_next_reply(disable_decoding)

            if isinstance(response, ConnectionError):
                raise response
            if self._hiredis_PushNotificationType is not None and isinstance(
                response, self._hiredis_PushNotificationType
            ):
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                # The caller wants a regular reply: skip this notification.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response

    async def _read_next_reply(self, disable_decoding: bool):
        """Parse one reply, reading from the stream until one is complete."""
        response = self._reader_gets(disable_decoding)
        while response is NOT_ENOUGH_DATA:
            await self.read_from_socket()
            response = self._reader_gets(disable_decoding)
        return response

    def _reader_gets(self, disable_decoding: bool):
        """Pop one parsed reply (or NOT_ENOUGH_DATA) from the hiredis reader."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()