1import sys
2from abc import ABC
3from asyncio import IncompleteReadError, StreamReader, TimeoutError
4from typing import Callable, List, Optional, Protocol, Union
5
# asyncio.timeout() is importable from the standard library on Python 3.11+;
# on older interpreters fall back to the third-party ``async_timeout``
# package. The version is compared as a tuple: testing major and minor
# separately would wrongly pick the fallback on any future major release
# whose minor number is below 11.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
10
11from ..exceptions import (
12 AskError,
13 AuthenticationError,
14 AuthenticationWrongNumberOfArgsError,
15 BusyLoadingError,
16 ClusterCrossSlotError,
17 ClusterDownError,
18 ConnectionError,
19 ExecAbortError,
20 MasterDownError,
21 ModuleError,
22 MovedError,
23 NoPermissionError,
24 NoScriptError,
25 OutOfMemoryError,
26 ReadOnlyError,
27 RedisError,
28 ResponseError,
29 TryAgainError,
30)
31from ..typing import EncodableT
32from .encoders import Encoder
33from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
34
# Message texts for module load/unload failures. They are used as lookup keys
# in BaseParser.EXCEPTION_CLASSES["ERR"], where each maps to ModuleError.
MODULE_LOAD_ERROR = (
    "Error loading the extension. "
    "Please check the server logs."
)
NO_SUCH_MODULE_ERROR = (
    "Error unloading module: "
    "no such module with that name"
)
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = (
    "Error unloading module: "
    "operation not possible."
)
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: "
    "the module exports one or more module-side data types, "
    "can't unload"
)
# Messages for the case where a client sends an AUTH command to a server
# that has no authentication configured. Both wordings map to
# AuthenticationError.
NO_AUTH_SET_ERROR = dict.fromkeys(
    (
        # wording used by Redis >= 6.0
        "AUTH <password> called without any password configured for the "
        "default user. Are you sure your configuration is correct?",
        # wording used by Redis < 6.0
        "Client sent AUTH, but no password is set",
    ),
    AuthenticationError,
)
52
53
class BaseParser(ABC):
    """Root of the parser classes; turns error replies into exception objects."""

    # Lookup table used by parse_error(). Top-level keys are compared with the
    # first space-separated word of the reply. A value is either an exception
    # class, or a nested dict mapping the rest of the message to an exception
    # class.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # on some Redis server versions the command name in this
            # syntax error is reported in lowercase ...
            "wrong number of arguments for "
            "'auth' command": AuthenticationWrongNumberOfArgsError,
            # ... and on others in uppercase
            "wrong number of arguments for "
            "'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response):
        """Build (but do not raise) the exception matching an error reply.

        The text before the first space is looked up in EXCEPTION_CLASSES.
        When it is a known code, the exception is constructed from the
        remaining text only; a nested table is consulted with that remaining
        text and falls back to ResponseError. When the code is unknown, a
        ResponseError carrying the full, unmodified reply is returned.
        """
        error_code, _, remainder = response.partition(" ")
        if error_code not in cls.EXCEPTION_CLASSES:
            return ResponseError(response)

        mapped = cls.EXCEPTION_CLASSES[error_code]
        if isinstance(mapped, dict):
            mapped = mapped.get(remainder, ResponseError)
        return mapped(remainder)

    def on_disconnect(self):
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError()

    def on_connect(self, connection):
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError()
106
107
class _RESPBase(BaseParser):
    """Shared state and connect/disconnect hooks for sync RESP parsers"""

    def __init__(self, socket_read_size):
        # Nothing is attached until on_connect() is called.
        self._buffer = None
        self._sock = None
        self.encoder = None
        self.socket_read_size = socket_read_size

    def __del__(self):
        # Best-effort cleanup: any error raised while disconnecting during
        # finalisation is deliberately ignored.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to the connection's socket and encoder, creating a fresh buffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(
            sock,
            self.socket_read_size,
            connection.socket_timeout,
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Drop the socket and encoder references and close the buffer, if any."""
        self._sock = None
        buffer = self._buffer
        if buffer is not None:
            buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return the buffer's can_read(timeout) result.

        When there is no (truthy) buffer, the falsy buffer value itself is
        returned instead.
        """
        buffer = self._buffer
        if not buffer:
            return buffer
        return buffer.can_read(timeout)
141
142
class AsyncBaseParser(BaseParser):
    """Common base of the async parsers implemented in Python"""

    __slots__ = ("_stream", "_read_size")

    def __init__(self, socket_read_size: int):
        # No stream is attached at construction time.
        self._read_size = socket_read_size
        self._stream: Optional[StreamReader] = None

    async def can_read_destructive(self) -> bool:
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError()

    async def read_response(
        self,
        disable_decoding: bool = False,
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError()
159
160
# Values of a push response's first element that mark it as an invalidation
# message (bytes and str spellings). The handle_push_response() methods below
# test response[0] against this list to choose the handler.
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
162
163
class PushNotificationsParser(Protocol):
    """Protocol describing how a parser dispatches RESP3 push responses"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None

    def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses (not implemented in the protocol)"""
        raise NotImplementedError()

    def handle_push_response(self, response, **kwargs):
        """Route a push response to the matching handler and return its result.

        A response whose first element is listed in _INVALIDATION_MESSAGE goes
        to the invalidation handler when one is set (otherwise None is
        returned); every other response goes to the pubsub handler.
        """
        if response[0] in _INVALIDATION_MESSAGE:
            handler = self.invalidation_push_handler_func
            if handler:
                return handler(response)
            return None
        return self.pubsub_push_handler_func(response)

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Store the callable used for pubsub push responses"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Store the callable used for invalidation push responses"""
        self.invalidation_push_handler_func = invalidation_push_handler_func
185
186
class AsyncPushNotificationsParser(Protocol):
    """Protocol describing how an async parser dispatches RESP3 push responses"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None

    async def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses (not implemented in the protocol)"""
        raise NotImplementedError()

    async def handle_push_response(self, response, **kwargs):
        """Route a push response to the matching handler and await its result.

        A response whose first element is listed in _INVALIDATION_MESSAGE goes
        to the invalidation handler when one is set (otherwise None is
        returned); every other response goes to the pubsub handler.
        """
        if response[0] in _INVALIDATION_MESSAGE:
            handler = self.invalidation_push_handler_func
            if handler:
                return await handler(response)
            return None
        return await self.pubsub_push_handler_func(response)

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Store the callable used for pubsub push responses"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Store the callable used for invalidation push responses"""
        self.invalidation_push_handler_func = invalidation_push_handler_func
211
212
class _AsyncRESPBase(AsyncBaseParser):
    """Base class for async resp parsing

    Holds a parsing buffer (``_buffer``), the current read offset into it
    (``_pos``), the data read from the stream while parsing (``_chunks``)
    and a flag recording whether on_connect() has been called more recently
    than on_disconnect() (``_connected``).
    """

    __slots__ = AsyncBaseParser.__slots__ + (
        "encoder",
        "_buffer",
        "_pos",
        "_chunks",
        "_connected",
    )

    def __init__(self, socket_read_size: int):
        super().__init__(socket_read_size)
        self.encoder: Optional[Encoder] = None
        self._buffer = b""
        self._chunks = []
        self._pos = 0
        # Start out disconnected so that can_read_destructive() raises
        # RedisError, not AttributeError, when called before on_connect().
        self._connected = False

    def _clear(self):
        """Empty the parsing buffer and forget the collected chunks."""
        self._buffer = b""
        self._chunks.clear()

    def on_connect(self, connection):
        """Called when the stream connects

        Raises:
            RedisError: if the connection has no reader.
        """
        self._stream = connection._reader
        if self._stream is None:
            raise RedisError("Buffer is closed.")
        self.encoder = connection.encoder
        self._clear()
        self._connected = True

    def on_disconnect(self):
        """Called when the stream disconnects"""
        self._connected = False

    async def can_read_destructive(self) -> bool:
        """Report whether buffered data is present or the stream is at EOF.

        Returns True when the parsing buffer is non-empty; otherwise returns
        the stream's ``at_eof()`` value, evaluated under a zero timeout
        (a TimeoutError yields False).

        Raises:
            RedisError: if the parser is not connected.
        """
        if not self._connected:
            raise RedisError("Buffer is closed.")
        if self._buffer:
            return True
        try:
            async with async_timeout(0):
                return self._stream.at_eof()
        except TimeoutError:
            return False

    async def _read(self, length: int) -> bytes:
        """
        Read `length` bytes of data. These are assumed to be followed
        by a '\r\n' terminator which is subsequently discarded.

        Raises:
            ConnectionError: if the stream ends before enough data arrived.
        """
        want = length + 2
        end = self._pos + want
        if len(self._buffer) >= end:
            # Everything needed is already buffered.
            result = self._buffer[self._pos : end - 2]
        else:
            # Fetch only the part that is missing from the buffer.
            tail = self._buffer[self._pos :]
            try:
                data = await self._stream.readexactly(want - len(tail))
            except IncompleteReadError as error:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += want
        return result

    async def _readline(self) -> bytes:
        """
        read an unknown number of bytes up to the next '\r\n'
        line separator, which is discarded.

        Raises:
            ConnectionError: if the line read from the stream is not
                terminated by '\r\n'.
        """
        found = self._buffer.find(b"\r\n", self._pos)
        if found >= 0:
            # A complete line is already buffered.
            result = self._buffer[self._pos : found]
        else:
            tail = self._buffer[self._pos :]
            data = await self._stream.readline()
            if not data.endswith(b"\r\n"):
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            result = (tail + data)[:-2]
            self._chunks.append(data)
        # Advance past the line and its two-byte terminator.
        self._pos += len(result) + 2
        return result