1import logging
2import sys
3from abc import ABC
4from asyncio import IncompleteReadError, StreamReader, TimeoutError
5from typing import Awaitable, Callable, List, Optional, Protocol, Union
6
7from redis.maint_notifications import (
8 MaintenanceNotification,
9 NodeFailedOverNotification,
10 NodeFailingOverNotification,
11 NodeMigratedNotification,
12 NodeMigratingNotification,
13 NodeMovingNotification,
14)
15
# asyncio.timeout() is available from Python 3.11; older interpreters use the
# third-party async_timeout package, imported under the same local name.
# Compare the whole version tuple: checking ``.minor`` alone would give the
# wrong answer on any future major version.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
20
21from ..exceptions import (
22 AskError,
23 AuthenticationError,
24 AuthenticationWrongNumberOfArgsError,
25 BusyLoadingError,
26 ClusterCrossSlotError,
27 ClusterDownError,
28 ConnectionError,
29 ExecAbortError,
30 MasterDownError,
31 ModuleError,
32 MovedError,
33 NoPermissionError,
34 NoScriptError,
35 OutOfMemoryError,
36 ReadOnlyError,
37 RedisError,
38 ResponseError,
39 TryAgainError,
40)
41from ..typing import EncodableT
42from .encoders import Encoder
43from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
44
# Server error texts (the part after the "ERR " prefix) reported when loading
# or unloading a module fails.
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module exports one or more "
    "module-side data types, can't unload"
)
# Replies a server sends when a user sends AUTH but no password is configured.
# Every variant maps to AuthenticationError.
NO_AUTH_SET_ERROR = dict.fromkeys(
    (
        # Redis >= 6.0
        "AUTH <password> called without any password configured for the "
        "default user. Are you sure your configuration is correct?",
        # Redis < 6.0
        "Client sent AUTH, but no password is set",
    ),
    AuthenticationError,
)

logger = logging.getLogger(__name__)
64
65
class BaseParser(ABC):
    """Common base for RESP parsers: error-reply decoding plus connection hooks."""

    # Maps the leading token of an error reply to the exception class to build.
    # A nested dict means the class depends on the exact message text that
    # follows the token; texts not listed there fall back to ResponseError.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # Some Redis server versions spell the command name in lowercase...
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # ...and some in uppercase.
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response):
        """Build (but do not raise) the exception matching an error reply.

        The text before the first space is looked up in ``EXCEPTION_CLASSES``.
        If it is unknown, a ``ResponseError`` wrapping the full reply is
        returned. Otherwise, the prefix and its following space are stripped,
        and the remaining message is passed to the selected exception class.
        """
        code = response.partition(" ")[0]
        if code not in cls.EXCEPTION_CLASSES:
            return ResponseError(response)

        message = response[len(code) + 1 :]
        exc_class = cls.EXCEPTION_CLASSES[code]
        if isinstance(exc_class, dict):
            # Sub-table keyed by exact message text.
            exc_class = exc_class.get(message, ResponseError)
        return exc_class(message)

    def on_disconnect(self):
        """Hook run when the connection is torn down; subclasses override it."""
        raise NotImplementedError()

    def on_connect(self, connection):
        """Hook run when a connection is established; subclasses override it."""
        raise NotImplementedError()
118
119
class _RESPBase(BaseParser):
    """Shared state and lifecycle hooks for the synchronous RESP parsers.

    Reads go through a ``SocketBuffer``. It is created in ``on_connect`` and
    closed in ``on_disconnect``.
    """

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self._sock = self._buffer = None
        self.encoder = None

    def __del__(self):
        # Best-effort cleanup at garbage collection: never let an error
        # escape a finalizer.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to ``connection``'s socket and wrap it in a SocketBuffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(
            sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Forget the socket, close the buffer (if any) and drop the encoder."""
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Delegate to the buffer's ``can_read``.

        Returns the (falsy) buffer value itself when no buffer is attached.
        """
        buf = self._buffer
        if not buf:
            return buf
        return buf.can_read(timeout)
153
154
class AsyncBaseParser(BaseParser):
    """Base parsing class for the python-backed async parser.

    Stores the configured read size and a stream reader slot that starts
    as ``None``. Concrete parsers implement the two coroutines below.
    """

    __slots__ = ("_stream", "_read_size")

    def __init__(self, socket_read_size: int):
        self._read_size = socket_read_size
        self._stream: Optional[StreamReader] = None

    async def can_read_destructive(self) -> bool:
        """Abstract: subclasses implement the readiness check."""
        raise NotImplementedError()

    async def read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        """Abstract: subclasses read and return one parsed reply."""
        raise NotImplementedError()
171
172
class MaintenanceNotificationsParser:
    """Static helpers that turn maintenance push messages into notifications.

    Each helper receives the push message as a sequence whose first element
    is the notification type name, and returns a notification object.
    """

    @staticmethod
    def parse_maintenance_start_msg(response, notification_type):
        """Build a notification from ``<notification_type> <seq_number> <time>``.

        :param response: the push message as a sequence.
        :param notification_type: callable invoked as
            ``notification_type(seq_number, time)``.
        """
        seq_number = response[1]
        ttl = response[2]
        return notification_type(seq_number, ttl)

    @staticmethod
    def parse_maintenance_completed_msg(response, notification_type):
        """Build a notification from ``<notification_type> <seq_number>``.

        :param response: the push message as a sequence.
        :param notification_type: callable invoked as
            ``notification_type(seq_number)``.
        """
        seq_number = response[1]
        return notification_type(seq_number)

    @staticmethod
    def parse_moving_msg(response):
        """Build a NodeMovingNotification from ``MOVING <seq_number> <time> <endpoint>``.

        A ``None`` endpoint yields ``host`` and ``port`` of ``None``.
        """
        seq_number = response[1]
        ttl = response[2]
        host, port = MaintenanceNotificationsParser._parse_endpoint(response[3])
        return NodeMovingNotification(seq_number, host, port, ttl)

    @staticmethod
    def _parse_endpoint(value):
        """Split a ``host:port`` endpoint (str or bytes) into ``(host, int(port))``.

        The split is on the *last* colon, so IPv6 hosts such as
        ``2001:db8::1:6379`` or ``[::1]:6379`` parse correctly. Enclosing
        square brackets are removed from the host. ``None`` yields
        ``(None, None)``.

        :raises ValueError: if there is no colon or the port is not an integer.
        """
        if value is None:
            return None, None
        if isinstance(value, bytes):
            value = value.decode()
        host, sep, port = value.rpartition(":")
        if not sep:
            raise ValueError(
                "Invalid endpoint {!r}: expected <host>:<port>".format(value)
            )
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port)
204
205
# Type names carried in the first element of a push message (compared after
# decoding bytes to str).
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"

# Notifications routed to ``maintenance_push_handler_func``. MOVING is not part
# of this group; it is routed to ``node_moving_push_handler_func`` instead.
_MAINTENANCE_MESSAGES = (
    _MIGRATING_MESSAGE, _MIGRATED_MESSAGE, _FAILING_OVER_MESSAGE, _FAILED_OVER_MESSAGE
)

# Notification type name -> (notification class, parse function).
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
    str, tuple[type[MaintenanceNotification], Callable]
] = dict(
    [
        (
            _MIGRATING_MESSAGE,
            (
                NodeMigratingNotification,
                MaintenanceNotificationsParser.parse_maintenance_start_msg,
            ),
        ),
        (
            _MIGRATED_MESSAGE,
            (
                NodeMigratedNotification,
                MaintenanceNotificationsParser.parse_maintenance_completed_msg,
            ),
        ),
        (
            _FAILING_OVER_MESSAGE,
            (
                NodeFailingOverNotification,
                MaintenanceNotificationsParser.parse_maintenance_start_msg,
            ),
        ),
        (
            _FAILED_OVER_MESSAGE,
            (
                NodeFailedOverNotification,
                MaintenanceNotificationsParser.parse_maintenance_completed_msg,
            ),
        ),
        (
            _MOVING_MESSAGE,
            (
                NodeMovingNotification,
                MaintenanceNotificationsParser.parse_moving_msg,
            ),
        ),
    ]
)
244
245
class PushNotificationsParser(Protocol):
    """Protocol defining RESP3-specific parsing functionality"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable] = None
    maintenance_push_handler_func: Optional[Callable] = None

    def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses"""
        raise NotImplementedError()

    def handle_push_response(self, response, **kwargs):
        """Route a RESP3 push message to the handler registered for its type.

        The first element of ``response`` names the message type; bytes are
        decoded first. Any type other than invalidation, MOVING or another
        maintenance message goes straight to ``pubsub_push_handler_func``.
        The remaining types reach their handler only when one is set. Errors
        raised while parsing or handling them are logged, and ``None`` is
        returned.
        """
        kind = response[0]
        if isinstance(kind, bytes):
            kind = kind.decode()

        special_kinds = (_INVALIDATION_MESSAGE, _MOVING_MESSAGE, *_MAINTENANCE_MESSAGES)
        if kind not in special_kinds:
            return self.pubsub_push_handler_func(response)

        try:
            if kind == _INVALIDATION_MESSAGE:
                if self.invalidation_push_handler_func:
                    return self.invalidation_push_handler_func(response)
            elif kind == _MOVING_MESSAGE:
                if self.node_moving_push_handler_func:
                    parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[kind][1]
                    return self.node_moving_push_handler_func(parse(response))
            elif self.maintenance_push_handler_func:
                # Only the _MAINTENANCE_MESSAGES types can reach this branch.
                entry = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[kind]
                notification = entry[1](response, entry[0])
                if notification is not None:
                    return self.maintenance_push_handler_func(notification)
        except Exception as e:
            logger.error(
                "Error handling {} message ({}): {}".format(kind, response, e)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Register the callable that receives all other push messages."""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Register the callable that receives "invalidate" push messages."""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Register the callable that receives parsed MOVING notifications."""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Register the callable that receives parsed maintenance notifications."""
        self.maintenance_push_handler_func = maintenance_push_handler_func
314
315
class AsyncPushNotificationsParser(Protocol):
    """Protocol defining async RESP3-specific parsing functionality"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

    async def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses asynchronously"""
        raise NotImplementedError()

    async def handle_push_response(self, response, **kwargs):
        """Route a RESP3 push message to the awaitable handler for its type.

        The first element of ``response`` names the message type; bytes are
        decoded once here. Any type other than invalidation, MOVING or another
        maintenance message is awaited on ``pubsub_push_handler_func``. The
        remaining types reach their handler only when one is set. Errors
        raised while parsing or handling them are logged, and ``None`` is
        returned.
        """
        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
        ):
            return await self.pubsub_push_handler_func(response)

        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return await self.invalidation_push_handler_func(response)

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                _, parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ]
                notification = parser_function(response)
                return await self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                (
                    notification_type,
                    parser_function,
                ) = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[msg_type]
                notification = parser_function(response, notification_type)

                if notification is not None:
                    return await self.maintenance_push_handler_func(notification)
        except Exception as e:
            logger.error(
                "Error handling {} message ({}): {}".format(msg_type, response, e)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Set the pubsub push handler function"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Set the invalidation push handler function"""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Set the awaitable handler for parsed MOVING notifications"""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Set the awaitable handler for parsed maintenance notifications"""
        self.maintenance_push_handler_func = maintenance_push_handler_func
390
391
class _AsyncRESPBase(AsyncBaseParser):
    """Base class for async RESP parsing.

    Parsing walks ``self._buffer`` from offset ``self._pos``. When the buffer
    runs short, the missing bytes are awaited from ``self._stream`` and also
    recorded in ``self._chunks``. ``_clear()`` discards both the buffer and
    the recorded chunks.
    """

    __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")

    def __init__(self, socket_read_size: int):
        super().__init__(socket_read_size)
        self.encoder: Optional[Encoder] = None
        self._buffer = b""
        self._chunks: List[bytes] = []
        self._pos = 0
        # Start disconnected, so can_read_destructive() called before
        # on_connect() raises RedisError("Buffer is closed.") rather than
        # AttributeError. ``_connected`` is not in __slots__; it is stored in
        # the instance __dict__, because BaseParser declares no __slots__.
        self._connected = False

    def _clear(self):
        """Drop the parse buffer and any chunks recorded from the stream."""
        self._buffer = b""
        self._chunks.clear()

    def on_connect(self, connection):
        """Called when the stream connects"""
        self._stream = connection._reader
        if self._stream is None:
            raise RedisError("Buffer is closed.")
        self.encoder = connection.encoder
        self._clear()
        self._connected = True

    def on_disconnect(self):
        """Called when the stream disconnects"""
        self._connected = False

    async def can_read_destructive(self) -> bool:
        """Return True if buffered data is pending, else the stream's ``at_eof()``.

        Returns False if the zero timeout fires first.

        :raises RedisError: if the parser is not connected.
        """
        if not self._connected:
            raise RedisError("Buffer is closed.")
        if self._buffer:
            return True
        try:
            async with async_timeout(0):
                return self._stream.at_eof()
        except TimeoutError:
            return False

    async def _read(self, length: int) -> bytes:
        """
        Read `length` bytes of data. These are assumed to be followed
        by a CRLF terminator which is subsequently discarded.

        :raises ConnectionError: if the stream ends before enough bytes arrive.
        """
        want = length + 2  # payload plus the trailing CRLF
        end = self._pos + want
        if len(self._buffer) >= end:
            # Entirely satisfied from the existing buffer.
            result = self._buffer[self._pos : end - 2]
        else:
            # Use what is left of the buffer and await the rest from the stream.
            tail = self._buffer[self._pos :]
            try:
                data = await self._stream.readexactly(want - len(tail))
            except IncompleteReadError as error:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += want
        return result

    async def _readline(self) -> bytes:
        """
        Read an unknown number of bytes up to the next CRLF
        line separator, which is discarded.

        :raises ConnectionError: if the stream ends before a CRLF is seen.
        """
        found = self._buffer.find(b"\r\n", self._pos)
        if found >= 0:
            result = self._buffer[self._pos : found]
        else:
            tail = self._buffer[self._pos :]
            data = await self._stream.readline()
            # readline() returns without the separator only at end of stream.
            if not data.endswith(b"\r\n"):
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += len(result) + 2
        return result