1import enum
2import ipaddress
3import logging
4import re
5import threading
6import time
7from abc import ABC, abstractmethod
8from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
9
10from redis.typing import Number
11
12if TYPE_CHECKING:
13 from redis.cluster import MaintNotificationsAbstractRedisCluster
14
15logger = logging.getLogger(__name__)
16
17
class MaintenanceState(enum.Enum):
    """
    Maintenance lifecycle state tracked for a connection.

    NONE: normal operation, no maintenance in progress.
    MOVING: the node's endpoint is being relocated (MOVING notification).
    MAINTENANCE: a migration/failover is in progress on the node.
    """

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"
22
23
class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command.

    The values are the literal strings sent to the server; do not change them.
    """

    INTERNAL_IP = "internal-ip"  # private/RFC1918 IP address
    INTERNAL_FQDN = "internal-fqdn"  # FQDN resolving to a private IP
    EXTERNAL_IP = "external-ip"  # publicly routable IP address
    EXTERNAL_FQDN = "external-fqdn"  # FQDN resolving to a public IP
    NONE = "none"  # no endpoint information requested

    def __str__(self):
        """Return the string value of the enum."""
        return self.value
36
37
38if TYPE_CHECKING:
39 from redis.connection import (
40 MaintNotificationsAbstractConnection,
41 MaintNotificationsAbstractConnectionPool,
42 )
43
44
class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp when the notification was created/read
        expire_at (float): Monotonic timestamp after which the notification is expired
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        # Reuse the deadline computed once in __init__ rather than re-deriving
        # it from creation_time + ttl here; keeps the two code paths from
        # drifting apart if the expiry rule ever changes.
        return time.monotonic() > self.expire_at

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass
122
123
class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.

    Attributes:
        new_node_host (Optional[str]): Host of the replacement node, or None
            when the push message did not include one.
        new_node_port (Optional[int]): Port of the replacement node, or None
            when the push message did not include one.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (Optional[str]): Hostname or IP address of the new
                replacement node, or None if not provided
            new_node_port (Optional[int]): Port number of the new replacement
                node, or None if not provided
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type class name, id,
                 new_node_host and new_node_port
        """
        # Normalize the port defensively: it may arrive as a numeric string
        # from the push message. Use an explicit None check (not truthiness)
        # so that port 0 hashes distinctly from a missing port, keeping
        # __hash__ consistent with __eq__, which compares ports directly.
        # int() can raise TypeError as well as ValueError for malformed
        # values; fall back to a fixed sentinel so hashing never raises.
        try:
            node_port = (
                int(self.new_node_port) if self.new_node_port is not None else None
            )
        except (TypeError, ValueError):
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )
202
203
class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has started migrating its slots.

    Received while a node is handing its slots over to another node during
    cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = deadline - time.monotonic()
        if time_left < 0:
            time_left = 0
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeMigratingNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash on the class name and id so instances can live in sets/dicts."""
        return hash((self.__class__.__name__, int(self.id)))
251
252
class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has finished migrating its slots.

    Received once a node has moved all of its slots to other nodes during
    cluster rebalancing or maintenance operations. Uses a fixed TTL.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = deadline - time.monotonic()
        if time_left < 0:
            time_left = 0
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash on the class name and id so instances can live in sets/dicts."""
        return hash((self.__class__.__name__, int(self.id)))
301
302
class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has started a failover.

    Received when a node begins failing over during cluster maintenance
    operations or while handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = deadline - time.monotonic()
        if time_left < 0:
            time_left = 0
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeFailingOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash on the class name and id so instances can live in sets/dicts."""
        return hash((self.__class__.__name__, int(self.id)))
350
351
class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has completed a failover.

    Received once a node has finished the failover process during cluster
    maintenance operations or after handling node failures. Uses a fixed TTL.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = deadline - time.monotonic()
        if time_left < 0:
            time_left = 0
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeFailedOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash on the class name and id so instances can live in sets/dicts."""
        return hash((self.__class__.__name__, int(self.id)))
400
401
class OSSNodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Note: ``slots`` is informational only - equality and hashing are based
    solely on the notification type and id.

    Args:
        id (int): Unique identifier for this notification
        slots (Optional[str]): Raw slot-range specification being migrated
            (as received from the server), or None if not provided
    """

    # Fixed TTL applied to every instance of this notification type.
    DEFAULT_TTL = 30

    def __init__(
        self,
        id: int,
        slots: Optional[str] = None,
    ):
        super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
        self.slots = slots

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"slots={self.slots}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two OSSNodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, OSSNodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))
457
458
class OSSNodeMigratedNotification(MaintenanceNotification):
    """
    Notification that a node has finished migrating its slots, as delivered
    to clients using the Redis OSS cluster API.

    Received once a node has moved all of its slots to other nodes during
    cluster rebalancing or maintenance operations. Uses a fixed TTL.

    Args:
        id (int): Unique identifier for this notification
        nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Maps each
            source node address ("host:port") to a list of destination
            mappings. Each destination mapping is a one-entry dict whose key
            is the destination node address ("host:port") and whose value is
            the migrated slot range.

            Example:
                {
                    "127.0.0.1:6379": [
                        {"127.0.0.1:6380": "1-100"},
                        {"127.0.0.1:6381": "101-200"},
                    ],
                    "127.0.0.1:6382": [
                        {"127.0.0.1:6383": "201-300"}
                    ],
                }
    """

    DEFAULT_TTL = 120

    def __init__(
        self,
        id: int,
        nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
    ):
        super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
        self.nodes_to_slots_mapping = nodes_to_slots_mapping

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = deadline - time.monotonic()
        if time_left < 0:
            time_left = 0
        details = ", ".join(
            (
                f"id={self.id}",
                f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, OSSNodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash on the class name and id so instances can live in sets/dicts."""
        return hash((self.__class__.__name__, int(self.id)))
532
533
534def _is_private_fqdn(host: str) -> bool:
535 """
536 Determine if an FQDN is likely to be internal/private.
537
538 This uses heuristics based on RFC 952 and RFC 1123 standards:
539 - .local domains (RFC 6762 - Multicast DNS)
540 - .internal domains (common internal convention)
541 - Single-label hostnames (no dots)
542 - Common internal TLDs
543
544 Args:
545 host (str): The FQDN to check
546
547 Returns:
548 bool: True if the FQDN appears to be internal/private
549 """
550 host_lower = host.lower().rstrip(".")
551
552 # Single-label hostnames (no dots) are typically internal
553 if "." not in host_lower:
554 return True
555
556 # Common internal/private domain patterns
557 internal_patterns = [
558 r"\.local$", # mDNS/Bonjour domains
559 r"\.internal$", # Common internal convention
560 r"\.corp$", # Corporate domains
561 r"\.lan$", # Local area network
562 r"\.intranet$", # Intranet domains
563 r"\.private$", # Private domains
564 ]
565
566 for pattern in internal_patterns:
567 if re.search(pattern, host_lower):
568 return True
569
570 # If none of the internal patterns match, assume it's external
571 return False
572
573
def add_debug_log_for_notification(
    connection: "MaintNotificationsAbstractConnection",
    notification: Union[str, MaintenanceNotification],
):
    """Emit a DEBUG log describing *notification* and the connection handling it.

    Best-effort: the local socket port is included when it can be read from
    the connection's socket, otherwise logged as None.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    local_port = None
    try:
        sock = connection._sock
        if sock:
            sockname = sock.getsockname()
            local_port = sockname[1] if sockname else None
    except (AttributeError, OSError):
        # Socket may be missing or already closed - port stays None.
        pass

    logger.debug(
        f"Handling maintenance notification: {notification}, "
        f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
        f"local socket port: {local_port}",
    )
593
594
class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.

    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
                If None, the endpoint type will be automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to send in CLIENT MAINT_NOTIFICATIONS.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
713
714
class MaintNotificationsPoolHandler:
    """
    Pool-level handler for MOVING maintenance notifications.

    Reacts to a MOVING push message by updating matching connections in the
    pool (relaxed timeouts, new host address), optionally triggering proactive
    reconnects, and scheduling a revert of those temporary changes once the
    notification's TTL elapses. The processed-notifications set and lock are
    shared between per-connection copies (see get_handler_for_connection).
    """

    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Notifications already acted upon; pruned of expired entries lazily.
        self._processed_notifications = set()
        # RLock: handler methods nest lock acquisition (handler lock -> pool lock).
        self._lock = threading.RLock()
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        """Attach the connection this handler instance is serving (used for logging/peer lookup)."""
        self.connection = connection

    def get_handler_for_connection(self):
        """Return a per-connection copy sharing the processed set and lock with this handler."""
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        """Drop expired entries from the shared processed-notifications set."""
        with self._lock:
            # Iterate over a snapshot since we mutate the set while looping.
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a pool-level notification; only MOVING is handled here."""
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        React to a MOVING notification (deduplicated via the shared processed set):
        update matching pool connections and the kwargs used for new connections,
        optionally reconnect proactively, and schedule the TTL-based revert.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled or is expired
                # just return
                return

            with self.pool._lock:
                logger.debug(
                    f"Handling node MOVING notification: {notification}, "
                    f"with connection: {self.connection}, connected to ip "
                    f"{self.connection.get_resolved_ip() if self.connection else None}"
                )
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any
                    # This is the address that is being moved
                    # and we need to handle only connections
                    # connected to the same address
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    # getattr returns the bound method (truthy) when the pool
                    # implements set_in_maintenance, False otherwise.
                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No replacement host yet - defer the reconnect to
                            # half the TTL so the server has time to publish it.
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                # Schedule the revert of the temporary settings for when the
                # notification's TTL elapses.
                threading.Timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    args=(notification,),
                ).start()

                self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.

        Runs on a Timer thread. Reverts the pool's connection kwargs (only if
        no newer MOVING notification has superseded this one) and restores the
        settings of connections tagged with this notification's hash.
        """
        notification_hash = hash(notification)

        with self._lock:
            logger.debug(
                f"Reverting temporary changes related to notification: {notification}, "
                f"with connection: {self.connection}, connected to ip "
                f"{self.connection.get_resolved_ip() if self.connection else None}"
            )
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                # Only revert the settings this handler actually changed.
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )
919
920
class MaintNotificationsConnectionHandler:
    """
    Connection-level handler for migration/failover maintenance notifications.

    On a "starting maintenance" notification it relaxes the connection's
    socket timeouts for the duration of the maintenance; on a "completed
    maintenance" notification it restores them. Connections in MOVING state
    are left untouched (the pool handler owns that state).
    """

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        OSSNodeMigratingNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
        OSSNodeMigratedNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch to the start/completed handler based on the notification class."""
        # get the notification type by checking its class in the _NOTIFICATION_TYPES dict
        notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)

        if notification_type is None:
            logger.error(f"Unhandled notification type: {notification}")
            return

        if notification_type:
            self.handle_maintenance_start_notification(
                MaintenanceState.MAINTENANCE, notification
            )
        else:
            self.handle_maintenance_completed_notification()

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
    ):
        """
        Enter maintenance state and relax this connection's socket timeouts.

        No-op when the connection is in MOVING state (owned by the pool
        handler) or when relaxed timeouts are disabled in the config.
        """
        add_debug_log_for_notification(self.connection, notification)

        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(
            tmp_relaxed_timeout=self.config.relaxed_timeout
        )
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
        if isinstance(notification, OSSNodeMigratingNotification):
            # add the notification id to the set of processed start maint notifications
            # this is used to skip the unrelaxing of the timeouts if we have received more than
            # one start notification before the final end notification
            self.connection.add_maint_start_notification(notification.id)

    def handle_maintenance_completed_notification(self):
        """
        Leave maintenance state and restore this connection's socket timeouts.

        No-op when the connection is in MOVING state or relaxed timeouts are
        disabled in the config.
        """
        # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED")
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relaxed timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.connection.reset_received_notifications()
994
995
996class OSSMaintNotificationsHandler:
997 def __init__(
998 self,
999 cluster_client: "MaintNotificationsAbstractRedisCluster",
1000 config: MaintNotificationsConfig,
1001 ) -> None:
1002 self.cluster_client = cluster_client
1003 self.config = config
1004 self._processed_notifications = set()
1005 self._in_progress = set()
1006 self._lock = threading.RLock()
1007
1008 def get_handler_for_connection(self):
1009 # Copy all data that should be shared between connections
1010 # but each connection should have its own pool handler
1011 # since each connection can be in a different state
1012 copy = OSSMaintNotificationsHandler(self.cluster_client, self.config)
1013 copy._processed_notifications = self._processed_notifications
1014 copy._in_progress = self._in_progress
1015 copy._lock = self._lock
1016 return copy
1017
1018 def remove_expired_notifications(self):
1019 with self._lock:
1020 for notification in tuple(self._processed_notifications):
1021 if notification.is_expired():
1022 self._processed_notifications.remove(notification)
1023
1024 def handle_notification(self, notification: MaintenanceNotification):
1025 if isinstance(notification, OSSNodeMigratedNotification):
1026 self.handle_oss_maintenance_completed_notification(notification)
1027 else:
1028 logger.error(f"Unhandled notification type: {notification}")
1029
    def handle_oss_maintenance_completed_notification(
        self, notification: OSSNodeMigratedNotification
    ):
        """Handle an SMIGRATED notification: refresh the cluster topology.

        Deduplicates the notification against the shared in-progress and
        processed sets, reinitializes the cluster slots/nodes caches using
        the migration destinations as extra startup nodes, then marks
        affected connections for reconnect and disconnects free connections
        to nodes that were dropped from the cluster.

        Args:
            notification: The SMIGRATED notification carrying the
                source-to-destination slot mapping.
        """
        # Prune stale entries first so the dedupe set does not grow unboundedly.
        self.remove_expired_notifications()

        with self._lock:
            if (
                notification in self._in_progress
                or notification in self._processed_notifications
            ):
                # We are already handling this notification or it has already
                # been processed. We should skip an in-progress notification
                # since when we reinitialize the cluster we execute a
                # CLUSTER SLOTS command that can use a different connection
                # that also has the notification, and we don't want to
                # process the same notification twice.
                return

            logger.debug(f"Handling SMIGRATED notification: {notification}")
            self._in_progress.add(notification)

            # Extract the information about the src and destination nodes that
            # are affected by the maintenance. nodes_to_slots_mapping structure:
            # {
            #     "src_host:port": [
            #         {"dest_host:port": "slot_range"},
            #         ...
            #     ],
            #     ...
            # }
            additional_startup_nodes_info = []
            affected_nodes = set()
            for (
                src_address,
                dest_mappings,
            ) in notification.nodes_to_slots_mapping.items():
                src_host, src_port = src_address.split(":")
                # NOTE(review): src_port stays a str here while dest_port below
                # is converted to int - confirm get_node() accepts a str port.
                src_node = self.cluster_client.nodes_manager.get_node(
                    host=src_host, port=src_port
                )
                if src_node is not None:
                    affected_nodes.add(src_node)

                # Every migration destination becomes an extra startup node
                # for the topology refresh below.
                for dest_mapping in dest_mappings:
                    for dest_address in dest_mapping.keys():
                        dest_host, dest_port = dest_address.split(":")
                        additional_startup_nodes_info.append(
                            (dest_host, int(dest_port))
                        )

            # Updates the cluster slots cache with the new slots mapping.
            # This will also update the nodes cache with the new nodes mapping.
            self.cluster_client.nodes_manager.initialize(
                disconnect_startup_nodes_pools=False,
                additional_startup_nodes_info=additional_startup_nodes_info,
            )

            # Union of the migration sources and the freshly refreshed
            # node cache - covers both surviving and dropped nodes.
            all_nodes = set(affected_nodes)
            all_nodes = all_nodes.union(
                self.cluster_client.nodes_manager.nodes_cache.values()
            )

            for current_node in all_nodes:
                if current_node.redis_connection is None:
                    continue
                with current_node.redis_connection.connection_pool._lock:
                    if current_node in affected_nodes:
                        # Mark for reconnect all in-use connections to the
                        # node - this will force them to disconnect after they
                        # complete their current commands. Some of them might
                        # be used by pub/sub and we don't know which ones - so
                        # we disconnect all in-flight connections after they
                        # are done with their current command execution.
                        for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
                            conn.mark_for_reconnect()

                    if (
                        current_node
                        not in self.cluster_client.nodes_manager.nodes_cache.values()
                    ):
                        # Disconnect all free connections to the node - this
                        # node will be dropped from the cluster, so we don't
                        # need to revert the timeouts.
                        for conn in current_node.redis_connection.connection_pool._get_free_connections():
                            conn.disconnect()

            # Mark the notification as processed and release the
            # in-progress guard.
            self._processed_notifications.add(notification)
            self._in_progress.remove(notification)