1import asyncio
2import socket
3import sys
4from logging import getLogger
5from typing import Callable, List, Optional, TypedDict, Union
6
# ``asyncio.timeout`` exists from Python 3.11; older interpreters use the
# ``async_timeout`` backport. Compare the full version tuple so that a future
# major version (e.g. 4.0, whose minor is 0) still takes the stdlib branch.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
11
12from ..exceptions import ConnectionError, InvalidResponse, RedisError
13from ..typing import EncodableT
14from ..utils import HIREDIS_AVAILABLE
15from .base import (
16 AsyncBaseParser,
17 AsyncPushNotificationsParser,
18 BaseParser,
19 PushNotificationsParser,
20)
21from .socket import (
22 NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
23 NONBLOCKING_EXCEPTIONS,
24 SENTINEL,
25 SERVER_CLOSED_CONNECTION_ERROR,
26)
27
# Sentinel passed to ``hiredis.Reader`` as ``notEnoughData``: ``Reader.gets()``
# returns it when the bytes fed so far do not yet form a complete reply.
# A dedicated object is used because `False` and `None` can be legitimate
# values parsed from RESP payloads, so they cannot reliably mean "need more data".
NOT_ENOUGH_DATA = object()
32
33
class _HiredisReaderArgs(TypedDict, total=False):
    """Keyword arguments passed to ``hiredis.Reader`` by the parsers in this module."""

    # Factory for the exception used on malformed protocol data.
    protocolError: Callable[[str], Exception]
    # Factory that turns a RESP error message into an exception instance.
    replyError: Callable[[str], Exception]
    # Text encoding used to decode replies (only set when decoding is enabled).
    encoding: Optional[str]
    # Error handler name used while decoding.
    errors: Optional[str]
    # Sentinel object returned by ``Reader.gets()`` when no complete reply is
    # buffered yet; the async parser's ``on_connect`` passes it through this type.
    notEnoughData: object
39
40
class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for synchronous connections using Hiredis.

    Raw bytes are read from the connection's socket into a reusable buffer and
    fed to a ``hiredis.Reader``, which assembles them into complete replies.
    """

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # Reusable receive buffer filled in place by ``socket.recv_into``.
        self._buffer = bytearray(socket_read_size)
        # Start in the "disconnected" state so that reading before
        # on_connect() raises ConnectionError rather than AttributeError.
        self._sock = None
        self._reader = None
        self._socket_timeout = None
        self._next_response = NOT_ENOUGH_DATA
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    def __del__(self):
        # Best-effort cleanup; a finalizer must never raise.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the message and return it unchanged."""
        logger = getLogger("push_response")
        # Lazy %-formatting: the message is only rendered if DEBUG is enabled.
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection, **kwargs):
        """Bind to ``connection``'s socket and create a fresh hiredis reader.

        Extra keyword arguments are accepted for interface compatibility and
        are not used.
        """
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        reader_kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            reader_kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**reader_kwargs)
        self._next_response = NOT_ENOUGH_DATA

        # hiredis < 3.2 has no PushNotification type; push messages are then
        # not distinguished from ordinary replies.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Drop the socket, the reader and any cached reply."""
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def can_read(self, timeout):
        """Return True if a reply is buffered or new data arrived within ``timeout``.

        A complete reply parsed here is cached and handed out by the next
        ``read_response`` call.

        Raises:
            ConnectionError: if the parser is not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is NOT_ENOUGH_DATA:
            self._next_response = self._reader.gets()
            if self._next_response is NOT_ENOUGH_DATA:
                return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Read one chunk from the socket and feed it to the hiredis reader.

        Args:
            timeout: temporary socket timeout for this read; ``SENTINEL`` keeps
                the connection's configured timeout.
            raise_on_timeout: when False, a timeout or non-blocking "no data"
                error returns False instead of raising.

        Returns:
            True if data was read, False if no data was available.

        Raises:
            ConnectionError: if the socket is gone, the server closed the
                connection, or the read failed.
            TimeoutError: on timeout when ``raise_on_timeout`` is True.
        """
        sock = self._sock
        if sock is None:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = sock.recv_into(self._buffer)
            if bufflen == 0:
                # recv of 0 bytes means the peer performed an orderly shutdown.
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # In non-blocking mode a "would block" error just means there is
            # no data yet; any other error is a real connection failure.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(
                f"Error while reading from socket: {ex.args}"
            ) from ex
        finally:
            if custom_timeout:
                # Restore the connection's configured timeout.
                sock.settimeout(self._socket_timeout)

    def read_response(self, disable_decoding=False, push_request=False):
        """Return the next reply from the server.

        A reply cached by ``can_read()`` is returned first. That reply was
        parsed with the reader's default decoding, so ``disable_decoding``
        does not apply to it.

        Push notifications (hiredis >= 3.2) are passed to
        ``handle_push_response``. When ``push_request`` is False they are
        consumed and reading continues with the next reply. Otherwise the
        handler's result is returned.

        Raises:
            ConnectionError: if the parser is disconnected, or if the reply is
                a ConnectionError (or a list whose first item is one).
        """
        # Loop instead of recursing so a burst of push notifications cannot
        # exhaust the interpreter's recursion limit.
        while True:
            if not self._reader:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

            if self._next_response is not NOT_ENOUGH_DATA:
                # Reply already parsed by can_read(); consume the cache.
                response = self._next_response
                self._next_response = NOT_ENOUGH_DATA
            else:
                if disable_decoding:
                    response = self._reader.gets(False)
                else:
                    response = self._reader.gets()
                while response is NOT_ENOUGH_DATA:
                    self.read_from_socket()
                    if disable_decoding:
                        response = self._reader.gets(False)
                    else:
                        response = self._reader.gets()

            # A ConnectionError reply (or one leading a list) means something
            # bad happened; raise it for both cached and freshly read replies.
            if isinstance(response, ConnectionError):
                raise response
            push_type = self._hiredis_PushNotificationType
            if push_type is not None and isinstance(response, push_type):
                response = self.handle_push_response(response)
                if push_request:
                    return response
                # Out-of-band push message: keep reading for the actual reply.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response
185
186
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis.

    Chunks read from the connection's stream reader are fed to a
    ``hiredis.Reader``, which assembles them into complete replies.
    """

    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        # Not connected until on_connect(); reads before that raise
        # ConnectionError instead of AttributeError.
        self._connected = False
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        """Default pub/sub push handler: log the message and return it unchanged."""
        logger = getLogger("push_response")
        # Lazy %-formatting: the message is only rendered if DEBUG is enabled.
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection):
        """Bind to ``connection``'s stream reader and create a fresh hiredis reader."""
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification type; getattr's default covers it.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        """Mark the parser disconnected; reads already in progress may finish."""
        self._connected = False

    async def can_read_destructive(self):
        """Return True if data is available without waiting.

        Destructive: a complete reply already buffered in the reader is
        consumed and discarded by this check.

        Raises:
            ConnectionError: if the parser is not connected.
        """
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._reader.gets() is not NOT_ENOUGH_DATA:
            return True
        try:
            # Zero timeout: only succeeds if data can be read immediately.
            async with async_timeout(0):
                return await self.read_from_socket()
        except asyncio.TimeoutError:
            return False

    async def read_from_socket(self):
        """Read one chunk from the stream and feed it to the hiredis reader.

        Returns:
            True once data has been read and buffered.

        Raises:
            ConnectionError: if the stream returned no data (connection closed).
        """
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        return True

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        """Return the next reply from the server.

        Push notifications (hiredis >= 3.2) are passed to
        ``handle_push_response``. When ``push_request`` is False they are
        consumed and reading continues with the next reply. Otherwise the
        handler's result is returned.

        Raises:
            ConnectionError: if ``on_disconnect()`` was called before a new
                reply is started, or if the reply is a ConnectionError (or a
                list whose first item is one).
        """
        # Loop instead of recursing so a burst of push notifications cannot
        # exhaust the interpreter's recursion limit.
        while True:
            # If `on_disconnect()` has been called, prohibit any more reads
            # even if they could happen because data might be present.
            # We still allow reads in progress to finish.
            if not self._connected:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

            if disable_decoding:
                response = self._reader.gets(False)
            else:
                response = self._reader.gets()
            while response is NOT_ENOUGH_DATA:
                await self.read_from_socket()
                if disable_decoding:
                    response = self._reader.gets(False)
                else:
                    response = self._reader.gets()

            # A ConnectionError reply (or one leading a list) means something
            # bad happened, so raise it.
            if isinstance(response, ConnectionError):
                raise response
            push_type = self._hiredis_PushNotificationType
            if push_type is not None and isinstance(response, push_type):
                response = await self.handle_push_response(response)
                if push_request:
                    return response
                # Out-of-band push message: keep reading for the actual reply.
                continue
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], ConnectionError)
            ):
                raise response[0]
            return response