1import asyncio
2import socket
3import sys
4from logging import getLogger
5from typing import Callable, List, Optional, TypedDict, Union
6
# asyncio.timeout was added in Python 3.11; older interpreters fall back to the
# async_timeout backport. Compare the whole version tuple: checking
# ``major >= 3 and minor >= 11`` separately would wrongly pick the backport on
# a future major version whose minor number is below 11 (e.g. 4.0).
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
11
12from ..exceptions import ConnectionError, InvalidResponse, RedisError
13from ..typing import EncodableT
14from ..utils import HIREDIS_AVAILABLE
15from .base import (
16 AsyncBaseParser,
17 AsyncPushNotificationsParser,
18 BaseParser,
19 PushNotificationsParser,
20)
21from .socket import (
22 NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
23 NONBLOCKING_EXCEPTIONS,
24 SENTINEL,
25 SERVER_CLOSED_CONNECTION_ERROR,
26)
27
# Private sentinel passed to hiredis as ``notEnoughData``; ``Reader.gets()``
# hands it back when the buffered bytes do not yet form a complete reply.
# A fresh ``object()`` is used (and compared by identity) because ``False``
# and ``None`` can both be legitimate parsed RESP values.
NOT_ENOUGH_DATA: object = object()
32
33
34class _HiredisReaderArgs(TypedDict, total=False):
35 protocolError: Callable[[str], Exception]
36 replyError: Callable[[str], Exception]
37 encoding: Optional[str]
38 errors: Optional[str]
39
40
class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for connections using Hiredis.

    Reads raw bytes from a blocking socket into a reusable buffer, feeds them
    to a ``hiredis.Reader`` and returns parsed replies. A reply parsed during
    ``can_read()`` is cached in ``_next_response`` and handed out by the next
    ``read_response()`` call.
    """

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # Reusable receive buffer: read_from_socket() fills it via recv_into()
        # and feeds only the bytes actually received to the hiredis reader.
        self._buffer = bytearray(socket_read_size)
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.oss_cluster_maint_push_handler_func = None
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None
        # Start in the disconnected state so that can_read()/read_response()
        # before on_connect() raise ConnectionError, not AttributeError.
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def __del__(self):
        # Best-effort cleanup at garbage collection; a finalizer must never
        # propagate exceptions, so failures are deliberately ignored.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the notification, return it as-is."""
        logger = getLogger("push_response")
        # Lazy %-style args: str(response) is only built if DEBUG is enabled.
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection, **kwargs):
        """Bind to ``connection``'s socket and create a fresh hiredis reader.

        Extra keyword arguments are accepted for interface compatibility and
        are not used.
        """
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        # Named separately so the ``**kwargs`` parameter is not clobbered.
        reader_kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            reader_kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**reader_kwargs)
        self._next_response = NOT_ENOUGH_DATA

        # hiredis < 3.2 does not expose PushNotification; push detection is
        # then disabled (the type stays None).
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Drop the socket, the reader and any cached reply."""
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def can_read(self, timeout):
        """Return whether data is available to read.

        If a complete reply is already buffered it is parsed, cached for the
        next ``read_response()`` and True is returned. Otherwise one socket
        read is attempted with ``timeout``; the result is True if *any* bytes
        arrived (not necessarily a complete reply) and False on timeout.

        Raises:
            ConnectionError: if the parser is not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is NOT_ENOUGH_DATA:
            self._next_response = self._reader.gets()
        if self._next_response is NOT_ENOUGH_DATA:
            return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Perform one ``recv_into`` and feed the received bytes to hiredis.

        Args:
            timeout: temporary socket timeout for this read; SENTINEL keeps
                the connection's configured timeout.
            raise_on_timeout: if False, a timeout or an allowed non-blocking
                error returns False instead of raising.

        Returns:
            True if data was read, False on a tolerated timeout/would-block.

        Raises:
            ConnectionError: if the peer closed the connection or the socket
                raised a non-tolerated error.
            TimeoutError: on timeout when ``raise_on_timeout`` is True.
        """
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = sock.recv_into(self._buffer)
            if bufflen == 0:
                # recv of 0 bytes means the peer closed the connection.
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # In non-blocking mode a "would block" error just means no data
            # is available yet; any other error is a real failure.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(
                f"Error while reading from socket: {ex.args}"
            ) from ex
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def _hiredis_gets(self, disable_decoding):
        """Pop one parsed reply from the reader, or NOT_ENOUGH_DATA."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()

    def _hiredis_next_reply(self, disable_decoding):
        """Return the next complete reply, preferring one cached by can_read()."""
        cached = self._next_response
        if cached is not NOT_ENOUGH_DATA:
            self._next_response = NOT_ENOUGH_DATA
            return cached
        response = self._hiredis_gets(disable_decoding)
        while response is NOT_ENOUGH_DATA:
            self.read_from_socket()
            response = self._hiredis_gets(disable_decoding)
        return response

    def _is_hiredis_push(self, response):
        """True if ``response`` is a hiredis PushNotification (hiredis >= 3.2)."""
        push_type = self._hiredis_PushNotificationType
        return push_type is not None and isinstance(response, push_type)

    def read_response(self, disable_decoding=False, push_request=False):
        """Read and return the next reply.

        Push notifications are passed to ``handle_push_response`` (inherited,
        not defined here). When ``push_request`` is True the handled push
        reply is returned; otherwise reading continues with the next reply.

        Raises:
            ConnectionError: if not connected, or if the reply is a
                ConnectionError or a list whose first item is one. This now
                applies equally to replies cached by ``can_read()``.
        """
        # Loop rather than recurse past push notifications so a burst of
        # them cannot exhaust the recursion limit.
        while True:
            if not self._reader:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            response = self._hiredis_next_reply(disable_decoding)

            if isinstance(response, ConnectionError):
                raise response
            if self._is_hiredis_push(response):
                response = self.handle_push_response(response)
                if push_request:
                    return response
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response
192
193
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis.

    Reads chunks from the connection's stream reader, feeds them to a
    ``hiredis.Reader`` and returns parsed replies.
    """

    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        # Not connected until on_connect(); early reads then fail with
        # ConnectionError/OSError instead of AttributeError.
        self._connected = False
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the notification, return it as-is."""
        logger = getLogger("push_response")
        # Lazy %-style args: str(response) is only built if DEBUG is enabled.
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection):
        """Bind to ``connection``'s stream reader and create a hiredis reader."""
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification; getattr's default covers it,
        # so no try/except is needed.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Mark the parser disconnected; later reads are refused."""
        self._connected = False

    async def can_read_destructive(self):
        """Return whether any data is pending on the connection.

        Destructive: a complete reply already buffered in the hiredis reader
        is consumed and discarded by this check. Otherwise the stream is
        polled with a zero timeout; False is returned if that times out.

        Raises:
            OSError: if the parser has been disconnected.
        """
        if not self._connected:
            raise OSError("Buffer is closed.")
        if self._reader.gets() is not NOT_ENOUGH_DATA:
            return True
        try:
            async with async_timeout(0):
                return await self.read_from_socket()
        except asyncio.TimeoutError:
            return False

    async def read_from_socket(self):
        """Read one chunk (up to ``_read_size`` bytes) and feed it to hiredis.

        Returns True when data was read.

        Raises:
            ConnectionError: if the stream returned no data (peer closed).
        """
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        return True

    def _hiredis_gets(self, disable_decoding: bool):
        """Pop one parsed reply from the reader, or NOT_ENOUGH_DATA."""
        if disable_decoding:
            return self._reader.gets(False)
        return self._reader.gets()

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        """Read and return the next reply.

        Push notifications are passed to ``handle_push_response`` (inherited,
        not defined here). When ``push_request`` is True the handled push
        reply is returned; otherwise reading continues with the next reply.

        Raises:
            ConnectionError: if disconnected, or if the reply is a
                ConnectionError or a list whose first item is one.
        """
        # Loop rather than recurse past push notifications so a burst of
        # them cannot exhaust the recursion limit.
        while True:
            # Once on_disconnect() has run, refuse new reads even if data is
            # still buffered; a read already in progress may finish.
            if not self._connected:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

            response = self._hiredis_gets(disable_decoding)
            while response is NOT_ENOUGH_DATA:
                await self.read_from_socket()
                response = self._hiredis_gets(disable_decoding)

            if isinstance(response, ConnectionError):
                raise response
            push_type = self._hiredis_PushNotificationType
            if push_type is not None and isinstance(response, push_type):
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response