1import asyncio
2import socket
3import sys
4from logging import getLogger
5from typing import Callable, List, Optional, TypedDict, Union
6
# ``asyncio.timeout`` exists from Python 3.11; older interpreters use the
# ``async_timeout`` package's ``timeout`` instead. Compare the version tuple
# so the check stays correct for any later major version.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
11
12from ..exceptions import ConnectionError, InvalidResponse, RedisError
13from ..typing import EncodableT
14from ..utils import HIREDIS_AVAILABLE
15from .base import (
16 AsyncBaseParser,
17 AsyncPushNotificationsParser,
18 BaseParser,
19 PushNotificationsParser,
20)
21from .socket import (
22 NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
23 NONBLOCKING_EXCEPTIONS,
24 SENTINEL,
25 SERVER_CLOSED_CONNECTION_ERROR,
26)
27
# Unique sentinel handed to ``hiredis.Reader`` as ``notEnoughData``: the reader
# hands it back when the bytes fed so far do not yet form a complete reply.
# A fresh ``object()`` is used because ``None`` and ``False`` can both be real
# values decoded from a RESP payload, so neither can mean "incomplete".
NOT_ENOUGH_DATA: object = object()
32
33
34class _HiredisReaderArgs(TypedDict, total=False):
35 protocolError: Callable[[str], Exception]
36 replyError: Callable[[str], Exception]
37 encoding: Optional[str]
38 errors: Optional[str]
39
40
class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for connections using Hiredis.

    Bytes are read from the connection's blocking socket into a reusable
    ``bytearray`` and fed to a ``hiredis.Reader``, which turns them into
    parsed replies. When the installed hiredis exposes a ``PushNotification``
    type, push replies are routed through ``handle_push_response``.
    """

    def __init__(self, socket_read_size):
        """Create a parser that reads up to ``socket_read_size`` bytes per recv.

        Raises:
            RedisError: if hiredis is not installed.
        """
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        self._buffer = bytearray(socket_read_size)
        # Start out "disconnected"; on_connect() fills these in. Reads before
        # connecting then raise ConnectionError instead of AttributeError.
        self._sock = None
        self._socket_timeout = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    def __del__(self):
        # Best-effort cleanup during garbage collection; never raise here.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the push at DEBUG and return it."""
        logger = getLogger("push_response")
        logger.debug("Push response: " + str(response))
        return response

    def on_connect(self, connection, **kwargs):
        """Bind to ``connection``'s socket and create a fresh hiredis reader.

        Extra keyword arguments are accepted but not used here.
        """
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        reader_kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            reader_kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**reader_kwargs)
        self._next_response = NOT_ENOUGH_DATA

        # hiredis < 3.2 has no PushNotification type. In that case push
        # replies are not distinguished from regular replies.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Drop the socket, the reader and any reply cached by can_read()."""
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def can_read(self, timeout):
        """Return True if a reply is already parsed or bytes arrive in time.

        A complete reply already buffered in the reader is parsed and cached
        in ``_next_response`` so that the next read_response() returns it.
        Otherwise a single socket read is attempted with ``timeout``, and its
        result is returned (True if any bytes were read).

        Raises:
            ConnectionError: if the parser is not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is NOT_ENOUGH_DATA:
            self._next_response = self._reader.gets()
        if self._next_response is NOT_ENOUGH_DATA:
            return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Read once from the socket and feed the bytes to the hiredis reader.

        Args:
            timeout: per-call socket timeout; SENTINEL keeps the connection's
                configured timeout. The configured timeout is restored
                afterwards.
            raise_on_timeout: when False, a socket timeout or a non-blocking
                "would block" error returns False instead of raising.

        Returns:
            True if data was read, False on a tolerated timeout/would-block.

        Raises:
            TimeoutError: on socket timeout when ``raise_on_timeout`` is True.
            ConnectionError: if the server closed the connection, or the
                socket raised a non-blocking error that is not tolerated.
        """
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = sock.recv_into(self._buffer)
            if bufflen == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            # Data was read from the socket and added to the reader.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # In non-blocking mode a "would block" errno just means there is
            # nothing to read yet; any other errno is a real failure.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def _next_from_reader(self, disable_decoding):
        """Return the reader's next reply, or NOT_ENOUGH_DATA if incomplete."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()

    def _is_push_notification(self, response):
        """True if ``response`` is a hiredis PushNotification (hiredis >= 3.2)."""
        push_type = self._hiredis_PushNotificationType
        return push_type is not None and isinstance(response, push_type)

    def read_response(self, disable_decoding=False, push_request=False):
        """Return the next reply, blocking on the socket until one is complete.

        Push notifications are passed to ``handle_push_response``. When
        ``push_request`` is True the handled push is returned. Otherwise
        reading continues until a non-push reply arrives. This is done in a
        loop rather than by recursion, so a burst of pushes cannot exhaust
        the stack.

        Raises:
            ConnectionError: if not connected, or if the reply is (or starts
                with) a ConnectionError instance.
        """
        while True:
            # Re-checked every pass: a push handler may have disconnected us.
            if not self._reader:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            if self._next_response is not NOT_ENOUGH_DATA:
                # Reply already parsed by can_read(). It was fetched with
                # decoding enabled, so ``disable_decoding`` cannot apply to it.
                response = self._next_response
                self._next_response = NOT_ENOUGH_DATA
            else:
                response = self._next_from_reader(disable_decoding)
                while response is NOT_ENOUGH_DATA:
                    self.read_from_socket()
                    response = self._next_from_reader(disable_decoding)

            # Cached and freshly read replies get the same error handling.
            if isinstance(response, ConnectionError):
                raise response
            if self._is_push_notification(response):
                response = self.handle_push_response(response)
                if push_request:
                    return response
                # Caller wants a regular reply; keep reading.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response
191
192
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis.

    Bytes are read from the connection's stream reader and fed to a
    ``hiredis.Reader``. When the installed hiredis exposes a
    ``PushNotification`` type, push replies are routed through
    ``handle_push_response``.
    """

    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        """Create a parser that reads up to ``socket_read_size`` bytes per read.

        Raises:
            RedisError: if hiredis is not available.
        """
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        # Not connected until on_connect(). Reads before that raise
        # ConnectionError instead of AttributeError.
        self._connected = False
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the push at DEBUG and return it."""
        logger = getLogger("push_response")
        logger.debug("Push response: " + str(response))
        return response

    def on_connect(self, connection):
        """Bind to ``connection``'s stream reader and create a hiredis reader."""
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification type. getattr's default
        # already covers that case, so no AttributeError handling is needed.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Mark the parser as disconnected; the reader itself is kept."""
        self._connected = False

    async def can_read_destructive(self):
        """Return True if a reply is buffered or data is immediately readable.

        Destructive: a complete reply already in the reader is taken out of
        it and discarded. Otherwise a zero-timeout socket read is attempted.

        Raises:
            ConnectionError: if not connected.
        """
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._reader.gets() is not NOT_ENOUGH_DATA:
            return True
        try:
            async with async_timeout(0):
                return await self.read_from_socket()
        except asyncio.TimeoutError:
            return False

    async def read_from_socket(self):
        """Read one chunk from the stream and feed it to the hiredis reader.

        Returns:
            True once data has been fed to the reader.

        Raises:
            ConnectionError: if the stream returned no data (server closed).
        """
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        # Data was read from the socket and added to the reader.
        return True

    def _next_from_reader(self, disable_decoding: bool):
        """Return the reader's next reply, or NOT_ENOUGH_DATA if incomplete."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()

    def _is_push_notification(self, response) -> bool:
        """True if ``response`` is a hiredis PushNotification (hiredis >= 3.2)."""
        push_type = self._hiredis_PushNotificationType
        return push_type is not None and isinstance(response, push_type)

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        """Return the next reply, awaiting the stream until one is complete.

        Push notifications are passed to ``handle_push_response``. When
        ``push_request`` is True the handled push is returned. Otherwise
        reading continues until a non-push reply arrives. This is done in a
        loop rather than by recursion.

        Raises:
            ConnectionError: if disconnected, or if the reply is (or starts
                with) a ConnectionError instance.
        """
        while True:
            # If on_disconnect() has been called, prohibit any more reads even
            # if data might still be buffered. Reads already in progress (the
            # inner loop below) are allowed to finish.
            if not self._connected:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

            response = self._next_from_reader(disable_decoding)
            while response is NOT_ENOUGH_DATA:
                await self.read_from_socket()
                response = self._next_from_reader(disable_decoding)

            if isinstance(response, ConnectionError):
                raise response
            if self._is_push_notification(response):
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                # Caller wants a regular reply; keep reading.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response