1import logging
2import sys
3from abc import ABC
4from asyncio import IncompleteReadError, StreamReader, TimeoutError
5from typing import Awaitable, Callable, List, Optional, Protocol, Union
6
7from redis.maintenance_events import (
8 MaintenanceEvent,
9 NodeFailedOverEvent,
10 NodeFailingOverEvent,
11 NodeMigratedEvent,
12 NodeMigratingEvent,
13 NodeMovingEvent,
14)
15
# asyncio.timeout exists from Python 3.11 on; older interpreters fall back to
# the third-party async_timeout package. Compare the version tuple as a whole:
# checking major and minor separately would wrongly reject e.g. 4.0.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
20
21from ..exceptions import (
22 AskError,
23 AuthenticationError,
24 AuthenticationWrongNumberOfArgsError,
25 BusyLoadingError,
26 ClusterCrossSlotError,
27 ClusterDownError,
28 ConnectionError,
29 ExecAbortError,
30 MasterDownError,
31 ModuleError,
32 MovedError,
33 NoPermissionError,
34 NoScriptError,
35 OutOfMemoryError,
36 ReadOnlyError,
37 RedisError,
38 ResponseError,
39 TryAgainError,
40)
41from ..typing import EncodableT
42from .encoders import Encoder
43from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
44
# Exact error texts the server sends for module load/unload failures. They
# are used as keys of the "ERR" table in BaseParser.EXCEPTION_CLASSES.
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module exports one or more "
    "module-side data types, can't unload"
)
# Replies a server gives when a client sends AUTH although no password is
# configured. Every variant maps to AuthenticationError.
NO_AUTH_SET_ERROR = dict.fromkeys(
    (
        # Redis >= 6.0
        "AUTH <password> called without any password "
        "configured for the default user. Are you sure "
        "your configuration is correct?",
        # Redis < 6.0
        "Client sent AUTH, but no password is set",
    ),
    AuthenticationError,
)

logger = logging.getLogger(__name__)
64
65
class BaseParser(ABC):
    """Common base for every RESP parser.

    Holds two things concrete parsers override or use:

    * the table that maps server error prefixes to exception classes, and
      the ``parse_error`` helper that consults it;
    * the connection lifecycle hooks, which concrete parsers must override.
    """

    # First word of an error reply -> exception class. "ERR" is special: its
    # value is a nested table keyed by the remainder of the message, and any
    # message missing from that table falls back to ResponseError.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # Server versions differ in whether the AUTH arity error names
            # the command in lowercase or uppercase, so both are listed.
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response):
        """Build the exception instance that represents error reply ``response``.

        When the first space-separated word is a known error code, the code
        and its following space are stripped, and the rest becomes the
        exception message. An unknown code yields ``ResponseError(response)``
        with the text untouched.
        """
        code, _, message = response.partition(" ")
        if code not in cls.EXCEPTION_CLASSES:
            return ResponseError(response)
        target = cls.EXCEPTION_CLASSES[code]
        if isinstance(target, dict):
            target = target.get(message, ResponseError)
        return target(message)

    def on_disconnect(self):
        """Lifecycle hook for connection teardown; concrete parsers override."""
        raise NotImplementedError()

    def on_connect(self, connection):
        """Lifecycle hook for ``connection`` setup; concrete parsers override."""
        raise NotImplementedError()
118
119
class _RESPBase(BaseParser):
    """Shared plumbing for the synchronous, socket-based RESP parsers."""

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        # The attributes below are bound by on_connect() and cleared again
        # by on_disconnect().
        self._sock = None
        self._buffer = None
        self.encoder = None

    def __del__(self):
        # Best-effort release when the parser is garbage collected; a
        # finalizer must never let an exception escape.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Take over ``connection``'s socket and encoder and wrap the socket in a SocketBuffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(
            sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Drop the socket and encoder references, closing the SocketBuffer if any."""
        self._sock = None
        buffer = self._buffer
        if buffer is not None:
            buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Delegate to ``SocketBuffer.can_read(timeout)``.

        When no buffer is attached (before on_connect() or after
        on_disconnect()), the falsy buffer value itself is returned; this is
        normally None.
        """
        buffer = self._buffer
        if not buffer:
            return buffer
        return buffer.can_read(timeout)
153
154
class AsyncBaseParser(BaseParser):
    """Common base for the pure-Python asyncio parsers.

    Stores the stream reader (bound later by subclasses) and the configured
    read size. ``can_read_destructive`` and ``read_response`` are stubs that
    subclasses must override.
    """

    __slots__ = ("_stream", "_read_size")

    def __init__(self, socket_read_size: int):
        self._read_size = socket_read_size
        self._stream: Optional[StreamReader] = None

    async def can_read_destructive(self) -> bool:
        """Stub; concrete async parsers must override."""
        raise NotImplementedError()

    async def read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        """Stub; concrete async parsers must override."""
        raise NotImplementedError()
171
172
class MaintenanceNotificationsParser:
    """Static helpers that turn maintenance push messages into event objects.

    Each helper receives the raw push ``response`` as an indexable sequence
    whose element 0 is the message type.
    """

    @staticmethod
    def parse_maintenance_start_msg(response, notification_type):
        """Return ``notification_type(seq_number, time)`` for a start message.

        Expected message format: ``<event_type> <seq_number> <time>``.
        """
        seq_number = response[1]
        ttl = response[2]
        return notification_type(seq_number, ttl)

    @staticmethod
    def parse_maintenance_completed_msg(response, notification_type):
        """Return ``notification_type(seq_number)`` for a completion message.

        Expected message format: ``<event_type> <seq_number>``.
        """
        seq_number = response[1]
        return notification_type(seq_number)

    @staticmethod
    def _parse_endpoint(value):
        """Split a ``host:port`` endpoint (str or bytes) into ``(host, int(port))``.

        ``"null"`` / ``b"null"`` yields ``(None, None)``. The split is on the
        *last* colon, so unbracketed IPv6 hosts such as ``"2001:db8::1:6379"``
        keep their own colons. A plain split on every colon raised ValueError
        for them.

        Raises:
            ValueError: if there is no ``:`` separator or the port is not an
                integer.
        """
        if value in (b"null", "null"):
            return None, None
        if isinstance(value, bytes):
            value = value.decode()
        host, port = value.rsplit(":", 1)
        return host, int(port)

    @staticmethod
    def parse_moving_msg(response):
        """Return a NodeMovingEvent for a MOVING message.

        Expected message format: ``MOVING <seq_number> <time> <endpoint>``,
        where ``<endpoint>`` is ``host:port`` or ``null``.
        """
        seq_number = response[1]
        ttl = response[2]
        host, port = MaintenanceNotificationsParser._parse_endpoint(response[3])
        return NodeMovingEvent(seq_number, host, port, ttl)
204
205
# Push-message type names (element 0 of a push response) recognised by the
# push-notification parsers in this module.
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE, _MIGRATED_MESSAGE = "MIGRATING", "MIGRATED"
_FAILING_OVER_MESSAGE, _FAILED_OVER_MESSAGE = "FAILING_OVER", "FAILED_OVER"

# Types routed to the maintenance handler. MOVING is intentionally absent:
# it is dispatched to its own dedicated handler.
_MAINTENANCE_MESSAGES = (
    _MIGRATING_MESSAGE,
    _MIGRATED_MESSAGE,
    _FAILING_OVER_MESSAGE,
    _FAILED_OVER_MESSAGE,
)
219
# Message type -> (event class, parser).
# Start/completion parsers are called as parser(response, event_class).
# The MOVING parser is called as parser(response) and builds its own event.
MSG_TYPE_TO_EVENT_PARSER_MAPPING: dict[str, tuple[type[MaintenanceEvent], Callable]] = {
    # Started: <event_type> <seq_number> <time>
    _MIGRATING_MESSAGE: (
        NodeMigratingEvent, MaintenanceNotificationsParser.parse_maintenance_start_msg
    ),
    # Completed: <event_type> <seq_number>
    _MIGRATED_MESSAGE: (
        NodeMigratedEvent, MaintenanceNotificationsParser.parse_maintenance_completed_msg
    ),
    # Started: <event_type> <seq_number> <time>
    _FAILING_OVER_MESSAGE: (
        NodeFailingOverEvent, MaintenanceNotificationsParser.parse_maintenance_start_msg
    ),
    # Completed: <event_type> <seq_number>
    _FAILED_OVER_MESSAGE: (
        NodeFailedOverEvent, MaintenanceNotificationsParser.parse_maintenance_completed_msg
    ),
    # MOVING <seq_number> <time> <endpoint>
    _MOVING_MESSAGE: (NodeMovingEvent, MaintenanceNotificationsParser.parse_moving_msg),
}
242
243
class PushNotificationsParser(Protocol):
    """Protocol defining RESP3 push-notification dispatch for sync parsers.

    ``handle_push_response`` routes a push message by its type, i.e. element
    0 of ``response``:

    * any type other than invalidation, MOVING or the other maintenance
      types goes to ``pubsub_push_handler_func``;
    * ``invalidate`` goes to ``invalidation_push_handler_func``, if set;
    * ``MOVING`` is parsed into an event and passed to
      ``node_moving_push_handler_func``, if set;
    * the remaining maintenance types are parsed into events and passed to
      ``maintenance_push_handler_func``, if set.

    Errors raised while parsing or handling the non-pubsub messages are
    logged rather than propagated.
    """

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable] = None
    maintenance_push_handler_func: Optional[Callable] = None

    def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses"""
        raise NotImplementedError()

    def handle_push_response(self, response, **kwargs):
        """Dispatch push ``response`` to the matching registered handler.

        Returns whatever the selected handler returns. Returns None when no
        handler is registered for the type or handling failed. ``kwargs`` is
        accepted for interface compatibility and ignored.
        """
        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
        ):
            return self.pubsub_push_handler_func(response)

        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return self.invalidation_push_handler_func(response)

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                _, parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type]
                notification = parser_function(response)
                return self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                notification_type, parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[
                    msg_type
                ]
                notification = parser_function(response, notification_type)
                if notification is not None:
                    return self.maintenance_push_handler_func(notification)
        except Exception as e:
            # Best-effort: failures are logged and swallowed rather than
            # propagated. Lazy %-args defer formatting until the record is
            # actually emitted.
            logger.error(
                "Error handling %s message (%s): %s", msg_type, response, e
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Register the handler for pubsub (non-special) push messages."""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Register the handler for ``invalidate`` push messages."""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Register the handler that receives parsed MOVING events."""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Register the handler that receives parsed maintenance events."""
        self.maintenance_push_handler_func = maintenance_push_handler_func
306
307
class AsyncPushNotificationsParser(Protocol):
    """Protocol defining RESP3 push-notification dispatch for async parsers.

    ``handle_push_response`` routes a push message by its type, i.e. element
    0 of ``response``, and awaits the chosen handler:

    * any type other than invalidation, MOVING or the other maintenance
      types goes to ``pubsub_push_handler_func``;
    * ``invalidate`` goes to ``invalidation_push_handler_func``, if set;
    * ``MOVING`` is parsed into an event and passed to
      ``node_moving_push_handler_func``, if set;
    * the remaining maintenance types are parsed into events and passed to
      ``maintenance_push_handler_func``, if set.

    Errors raised while parsing or handling the non-pubsub messages are
    logged rather than propagated.
    """

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

    async def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses asynchronously"""
        raise NotImplementedError()

    async def handle_push_response(self, response, **kwargs):
        """Dispatch push ``response`` to the matching handler and await it.

        Returns the awaited handler result. Returns None when no handler is
        registered for the type or handling failed. ``kwargs`` is accepted
        for interface compatibility and ignored.
        """
        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
        ):
            return await self.pubsub_push_handler_func(response)

        # msg_type is already a str at this point; no second decode needed.
        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return await self.invalidation_push_handler_func(response)

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                _, parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type]
                notification = parser_function(response)
                return await self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                notification_type, parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[
                    msg_type
                ]
                notification = parser_function(response, notification_type)
                if notification is not None:
                    return await self.maintenance_push_handler_func(notification)
        except Exception as e:
            # Best-effort: failures are logged and swallowed rather than
            # propagated. Lazy %-args defer formatting until the record is
            # actually emitted.
            logger.error(
                "Error handling %s message (%s): %s", msg_type, response, e
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Set the pubsub push handler function"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Set the invalidation push handler function"""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Register the coroutine function that receives parsed MOVING events."""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Register the coroutine function that receives parsed maintenance events."""
        self.maintenance_push_handler_func = maintenance_push_handler_func
376
377
class _AsyncRESPBase(AsyncBaseParser):
    """Base class for the pure-Python asyncio RESP parsers.

    Reads are served from ``_buffer``, starting at the ``_pos`` cursor. The
    connection's StreamReader is only consulted for bytes not yet buffered.
    Every chunk fetched from the stream is also appended to ``_chunks``.
    NOTE(review): presumably subclasses fold these chunks back into
    ``_buffer`` so an interrupted reply can be re-parsed; confirm in the
    concrete parsers.
    """

    __slots__ = AsyncBaseParser.__slots__ + (
        "encoder",
        "_buffer",
        "_pos",
        "_chunks",
        "_connected",
    )

    def __init__(self, socket_read_size: int):
        super().__init__(socket_read_size)
        self.encoder: Optional[Encoder] = None
        self._buffer = b""
        self._chunks = []
        self._pos = 0
        # Start out disconnected. can_read_destructive() then raises
        # RedisError, not AttributeError, if called before on_connect().
        self._connected = False

    def _clear(self):
        """Discard buffered bytes and recorded chunks.

        ``_pos`` is left untouched. NOTE(review): callers appear responsible
        for resetting it.
        """
        self._buffer = b""
        self._chunks.clear()

    def on_connect(self, connection):
        """Called when the stream connects.

        Binds ``connection``'s reader and encoder and clears buffered state.

        Raises:
            RedisError: if the connection has no reader.
        """
        self._stream = connection._reader
        if self._stream is None:
            raise RedisError("Buffer is closed.")
        self.encoder = connection.encoder
        self._clear()
        self._connected = True

    def on_disconnect(self):
        """Called when the stream disconnects; buffered data is kept as-is."""
        self._connected = False

    async def can_read_destructive(self) -> bool:
        """Return True if data is buffered, otherwise the stream's ``at_eof()``.

        Returns False if the zero-second timeout fires.

        Raises:
            RedisError: if the parser is not connected.
        """
        if not self._connected:
            raise RedisError("Buffer is closed.")
        if self._buffer:
            return True
        try:
            async with async_timeout(0):
                return self._stream.at_eof()
        except TimeoutError:
            return False

    async def _read(self, length: int) -> bytes:
        """Return the next ``length`` bytes and skip the 2-byte terminator.

        The two bytes after the payload (the CRLF terminator) are discarded
        without being checked. Bytes come from ``_buffer`` when it already
        holds them. Otherwise the missing part is read from the stream and
        recorded in ``_chunks``.

        Raises:
            ConnectionError: if the stream ends before enough bytes arrive.
        """
        want = length + 2  # payload plus the trailing CRLF
        end = self._pos + want
        if len(self._buffer) >= end:
            result = self._buffer[self._pos : end - 2]
        else:
            tail = self._buffer[self._pos :]
            try:
                data = await self._stream.readexactly(want - len(tail))
            except IncompleteReadError as error:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += want
        return result

    async def _readline(self) -> bytes:
        """Return the bytes up to the next CRLF, consuming but dropping the CRLF.

        Raises:
            ConnectionError: if the stream ends before a complete line.
        """
        found = self._buffer.find(b"\r\n", self._pos)
        if found >= 0:
            result = self._buffer[self._pos : found]
        else:
            tail = self._buffer[self._pos :]
            data = await self._stream.readline()
            line = tail + data
            # Validate the joined line, not just ``data``. If the buffered
            # tail ends in b"\r", readline() returns only the matching b"\n".
            if not line.endswith(b"\r\n"):
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            result = line[:-2]
            self._chunks.append(data)
        self._pos += len(result) + 2
        return result