1import select
2import socket
3from logging import getLogger
4from typing import Callable, List, Optional, TypedDict, Union
5
6from ..exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError
7from ..typing import EncodableT
8from ..utils import HIREDIS_AVAILABLE, SENTINEL, deprecated_function
9from .base import (
10 AsyncBaseParser,
11 AsyncPushNotificationsParser,
12 BaseParser,
13 PushNotificationsParser,
14)
15from .socket import (
16 NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
17 NONBLOCKING_EXCEPTIONS,
18 SERVER_CLOSED_CONNECTION_ERROR,
19)
20
# Sentinel used to signal that hiredis-py does not have enough data to parse
# a complete reply. Using `False` or `None` is not reliable, given that the
# parser can return `False` or `None` for legitimate reasons from RESP
# payloads. The object is handed to `hiredis.Reader` as `notEnoughData` and
# the parsers in this module compare against it by identity (`is`).
NOT_ENOUGH_DATA = object()
25
26
def _socket_can_read(sock, timeout: float) -> bool:
    """Report whether ``sock`` currently has something to read.

    Args:
        sock: A socket-like object accepted by ``select.select``. When it
            exposes a ``pending()`` method (as SSL sockets do), that method
            is consulted first.
        timeout: Seconds handed to ``select.select`` as its timeout.

    Returns:
        ``True`` when ``pending()`` reports buffered bytes or when
        ``select`` marks the socket readable within ``timeout``;
        ``False`` otherwise.
    """
    # SSL sockets can hold already-decrypted bytes above the OS socket
    # layer, which select() cannot see, so ask the socket itself first.
    has_buffered = hasattr(sock, "pending") and bool(sock.pending())
    if not has_buffered:
        readable, _, _ = select.select([sock], [], [], timeout)
        return bool(readable)
    return True
32
33
class _HiredisReaderArgs(TypedDict, total=False):
    """Keyword arguments this module passes to ``hiredis.Reader``.

    Every key is optional (``total=False``): the parsers only add
    ``encoding`` when response decoding is enabled.
    """

    # Callable used to build the exception for a protocol-level failure
    # (the parsers pass ``InvalidResponse``).
    protocolError: Callable[[str], Exception]
    # Callable used to build the exception for an error reply
    # (the parsers pass their ``parse_error`` method).
    replyError: Callable[[str], Exception]
    # Text encoding for decoded replies; left out to keep raw bytes.
    encoding: Optional[str]
    # Error-handling scheme that accompanies ``encoding``.
    errors: Optional[str]
    # Sentinel object the reader returns from ``gets()`` when the buffered
    # data does not yet contain a full reply (see ``NOT_ENOUGH_DATA``).
    notEnoughData: object
39
40
class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for connections using Hiredis.

    Bytes are received from the connection's socket into a reusable buffer
    and fed to a ``hiredis.Reader``, which produces the parsed replies.
    """

    def __init__(self, socket_read_size):
        """Create a disconnected parser.

        Args:
            socket_read_size: Size in bytes of the reusable receive buffer.

        Raises:
            RedisError: If hiredis is not available.
        """
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        self._buffer = bytearray(socket_read_size)
        # Start out disconnected so that can_read()/read_response() called
        # before on_connect() raise ConnectionError rather than
        # AttributeError. on_connect() fills these in.
        self._sock = None
        self._reader = None
        self._socket_timeout = None
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.oss_cluster_maint_push_handler_func = None
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise, so any failure
        # while dropping the socket/reader references is ignored on purpose.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the push at debug level and
        return it unchanged."""
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection, **kwargs):
        """Bind the parser to ``connection`` and create a fresh reader.

        Args:
            connection: Object providing ``_sock``, ``socket_timeout`` and
                ``encoder`` (with ``encoding``, ``encoding_errors`` and
                ``decode_responses``).
            **kwargs: Accepted for interface compatibility; not used.
        """
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        reader_kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }

        # Only ask hiredis to decode when the connection wants str replies.
        if connection.encoder.decode_responses:
            reader_kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**reader_kwargs)

        # hiredis < 3.2 has no PushNotification type; None disables the
        # push-notification branch in read_response().
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Drop the socket and reader references."""
        self._sock = None
        self._reader = None

    def can_read(self, timeout: float = 0) -> bool:
        """Return whether there is buffered or socket data to read.

        Args:
            timeout: Seconds to wait for the socket to become readable.

        Raises:
            ConnectionError: If the parser has no reader (not connected).
        """
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._reader.has_data():
            return True
        return _socket_can_read(self._sock, timeout)

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Receive one chunk from the socket and feed it to the reader.

        Args:
            timeout: Socket timeout to apply for this read only. When left
                as ``SENTINEL`` the socket's current timeout is untouched.
            raise_on_timeout: When false, a timeout (or a would-block error
                on a non-blocking socket) yields ``False`` instead of
                raising.

        Returns:
            ``True`` if data was read and fed to the reader, ``False`` if
            nothing was available and ``raise_on_timeout`` is false.

        Raises:
            TimeoutError: On a socket timeout, or on a would-block error
                with ``timeout == 0``, when ``raise_on_timeout`` is true.
            ConnectionError: If the peer closed the connection or any other
                non-blocking-class error occurred.
        """
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = sock.recv_into(self._buffer)
            # A zero-length read means the server closed the connection.
            if bufflen == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if ex.errno == allowed:
                if not raise_on_timeout:
                    return False
                if timeout == 0:
                    raise TimeoutError("Timeout reading from socket")
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            # Restore the connection-level timeout after a one-off override.
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(
        self,
        disable_decoding=False,
        push_request=False,
        timeout: Union[float, object] = SENTINEL,
    ):
        """Read and return the next reply.

        Push notifications are passed to ``handle_push_response``. Unless
        ``push_request`` is true, reading then continues until a regular
        reply arrives.

        Args:
            disable_decoding: Ask the reader not to decode this reply.
            push_request: Return the handled push notification instead of
                continuing to the next reply.
            timeout: Forwarded to ``read_from_socket`` for each socket read.

        Raises:
            ConnectionError: If the parser is not connected, or the reply
                (or the first element of a list reply) is a ConnectionError.
        """
        gets_args = (False,) if disable_decoding else ()
        push_type = self._hiredis_PushNotificationType

        # Loop rather than recurse on push notifications, so a long run of
        # consecutive pushes cannot exhaust the interpreter's recursion limit.
        while True:
            # Re-checked on every pass: a push handler may have disconnected.
            if not self._reader:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            response = self._reader.gets(*gets_args)
            while response is NOT_ENOUGH_DATA:
                self.read_from_socket(timeout=timeout)
                response = self._reader.gets(*gets_args)

            # if the response is a ConnectionError or the response is a list
            # and the first item is a ConnectionError, raise it as something
            # bad happened
            if isinstance(response, ConnectionError):
                raise response

            if push_type is not None and isinstance(response, push_type):
                response = self.handle_push_response(response)
                if push_request:
                    return response
                continue

            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response
180
181
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis.

    Bytes are awaited from the connection's stream reader and fed to a
    ``hiredis.Reader``, which produces the parsed replies.
    """

    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        """Create a parser with no reader attached yet.

        Args:
            socket_read_size: Forwarded to the base class initializer.

        Raises:
            RedisError: If hiredis is not available.
        """
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the push at debug level and
        return it unchanged."""
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection):
        """Bind the parser to ``connection`` and create a fresh reader.

        Args:
            connection: Object providing ``_reader`` (the stream to read
                from) and ``encoder`` (with ``encoding``,
                ``encoding_errors`` and ``decode_responses``).
        """
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        # Only ask hiredis to decode when the connection wants str replies.
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification type; None disables the
        # push-notification branch in read_response().
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Mark the parser as disconnected so further reads are refused."""
        self._connected = False

    @deprecated_function(
        version="8.0.0", reason="Use can_read() instead", name="can_read_destructive"
    )
    async def can_read_destructive(self) -> bool:
        """Deprecated alias that delegates to ``can_read()``."""
        return await self.can_read()

    async def can_read(self) -> bool:
        """Return whether buffered data or an EOF is pending.

        Raises:
            OSError: If the parser is not connected.
        """
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        if not self._connected:
            raise OSError("Buffer is closed.")
        # EOF means the connection is closed and not safe to reuse.
        if self._reader.has_data() or self._stream.at_eof():
            return True
        # asyncio.StreamReader has no public non-destructive API for checking
        # buffered bytes. Preserve dirty-connection detection for hiredis; tests
        # with a real StreamReader guard this private buffer API in CI.
        return bool(self._stream._buffer)

    async def read_from_socket(self):
        """Await one chunk from the stream and feed it to the reader.

        Returns:
            ``True`` once data has been read and fed to the reader.

        Raises:
            ConnectionError: If the read yields no data or a non-bytes
                value.
        """
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        # data was read from the socket and added to the buffer.
        # return True to indicate that data was read.
        return True

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        """Read and return the next reply.

        Push notifications are passed to ``handle_push_response``. Unless
        ``push_request`` is true, reading then continues until a regular
        reply arrives.

        Args:
            disable_decoding: Ask the reader not to decode this reply.
            push_request: Return the handled push notification instead of
                continuing to the next reply.

        Raises:
            ConnectionError: If the parser is disconnected, or the reply
                (or the first element of a list reply) is a ConnectionError.
        """
        gets_args = (False,) if disable_decoding else ()
        push_type = self._hiredis_PushNotificationType

        # Loop rather than recurse on push notifications, so a long run of
        # consecutive pushes cannot exhaust the interpreter's recursion limit.
        while True:
            # If `on_disconnect()` has been called, prohibit any more reads
            # even if they could happen because data might be present.
            # We still allow reads in progress to finish
            if not self._connected:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

            response = self._reader.gets(*gets_args)
            while response is NOT_ENOUGH_DATA:
                await self.read_from_socket()
                response = self._reader.gets(*gets_args)

            # if the response is a ConnectionError or the response is a list
            # and the first item is a ConnectionError, raise it as something
            # bad happened
            if isinstance(response, ConnectionError):
                raise response

            if push_type is not None and isinstance(response, push_type):
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                continue

            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response