1import enum
2import ipaddress
3import logging
4import re
5import threading
6import time
7from abc import ABC, abstractmethod
8from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
9
10from redis.observability.attributes import get_pool_name
11from redis.observability.recorder import (
12 record_connection_handoff,
13 record_connection_relaxed_timeout,
14 record_maint_notification_count,
15)
16from redis.typing import Number
17
18if TYPE_CHECKING:
19 from redis.cluster import MaintNotificationsAbstractRedisCluster
20
21logger = logging.getLogger(__name__)
22
23
class MaintenanceState(enum.Enum):
    """
    Maintenance-related state a connection can be in.

    NONE: no maintenance activity in effect (default, and the state restored
        once temporary maintenance changes are reverted).
    MOVING: the connected endpoint is being replaced by a new node; pool-level
        temporary settings (host override, relaxed timeouts) may be applied.
    MAINTENANCE: node-level maintenance (slot migration / failover) is in
        progress on this connection's node; relaxed timeouts may be applied.
    """

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"
28
29
class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self) -> str:
        """Return the string value of the enum."""
        return self.value
42
43
44if TYPE_CHECKING:
45 from redis.connection import (
46 MaintNotificationsAbstractConnection,
47 MaintNotificationsAbstractConnectionPool,
48 )
49
50
class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp when the notification was created/read
        expire_at (float): Monotonic timestamp after which the notification is expired
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        # time.monotonic() is used (not time.time()) so expiry is immune to
        # wall-clock adjustments.
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        # Use the expiry deadline precomputed in __init__ instead of
        # recomputing creation_time + ttl, keeping the two in sync.
        return time.monotonic() > self.expire_at

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass
128
129
class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.

    Attributes (in addition to the base class):
        new_node_host (Optional[str]): Host of the replacement node, or None
            when the push notification did not include an address.
        new_node_port (Optional[int]): Port of the replacement node.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (Optional[str]): Hostname or IP address of the new
                replacement node, may be None
            new_node_port (Optional[int]): Port number of the new replacement
                node, may be None
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type class name, id,
                new_node_host and new_node_port
        """
        # Use an explicit "is not None" test so a port of 0 is preserved
        # rather than collapsed to None by truthiness, and catch TypeError
        # in addition to ValueError so any non-int-convertible value falls
        # back to 0 instead of propagating out of __hash__.
        try:
            node_port = (
                int(self.new_node_port) if self.new_node_port is not None else None
            )
        except (TypeError, ValueError):
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )
208
209
class NodeMigratingNotification(MaintenanceNotification):
    """
    Sent while a Redis cluster node is migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeMigratingNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
257
258
class NodeMigratedNotification(MaintenanceNotification):
    """
    Sent once a Redis cluster node has finished migrating all of its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    # Completion notifications carry no TTL from the server; a short fixed
    # lifetime is used instead.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeMigratedNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
307
308
class NodeFailingOverNotification(MaintenanceNotification):
    """
    Sent while a Redis cluster node is in the process of failing over,
    during cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeFailingOverNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
356
357
class NodeFailedOverNotification(MaintenanceNotification):
    """
    Sent once a Redis cluster node has completed a failover, during cluster
    maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    # Completion notifications carry no TTL from the server; a short fixed
    # lifetime is used instead.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, NodeFailedOverNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
406
407
class OSSNodeMigratingNotification(MaintenanceNotification):
    """
    OSS-API variant of the "node migrating" notification: received when a
    Redis OSS API client is used and a node starts migrating its slots to
    another node during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        slots (Optional[str]): Slot range(s) being migrated, as carried by
            the push notification (may be None).
    """

    DEFAULT_TTL = 30

    def __init__(
        self,
        id: int,
        slots: Optional[str] = None,
    ):
        super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
        self.slots = slots

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"slots={self.slots}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, OSSNodeMigratingNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
463
464
class OSSNodeMigratedNotification(MaintenanceNotification):
    """
    OSS-API variant of the "node migrated" notification: received when a
    Redis OSS API client is used and a node has finished migrating all of
    its slots during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Map of
            source node address ("host:port") to a list of destination
            mappings; each destination mapping is a dict with the destination
            node address as key and the slot range as value.

            Structure example:
                {
                    "127.0.0.1:6379": [
                        {"127.0.0.1:6380": "1-100"},
                        {"127.0.0.1:6381": "101-200"}
                    ],
                    "127.0.0.1:6382": [
                        {"127.0.0.1:6383": "201-300"}
                    ]
                }
    """

    DEFAULT_TTL = 120

    def __init__(
        self,
        id: int,
        nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
    ):
        super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
        self.nodes_to_slots_mapping = nodes_to_slots_mapping

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        fields = ", ".join(
            (
                f"id={self.id}",
                f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, other) -> bool:
        """Equal when *other* is the exact same notification type with the same id."""
        return (
            isinstance(other, OSSNodeMigratedNotification)
            and type(other) is type(self)
            and other.id == self.id
        )

    def __hash__(self) -> int:
        """Hash on (class name, id) so instances can be used in sets and as dict keys."""
        return hash((type(self).__name__, int(self.id)))
538
539
540def _is_private_fqdn(host: str) -> bool:
541 """
542 Determine if an FQDN is likely to be internal/private.
543
544 This uses heuristics based on RFC 952 and RFC 1123 standards:
545 - .local domains (RFC 6762 - Multicast DNS)
546 - .internal domains (common internal convention)
547 - Single-label hostnames (no dots)
548 - Common internal TLDs
549
550 Args:
551 host (str): The FQDN to check
552
553 Returns:
554 bool: True if the FQDN appears to be internal/private
555 """
556 host_lower = host.lower().rstrip(".")
557
558 # Single-label hostnames (no dots) are typically internal
559 if "." not in host_lower:
560 return True
561
562 # Common internal/private domain patterns
563 internal_patterns = [
564 r"\.local$", # mDNS/Bonjour domains
565 r"\.internal$", # Common internal convention
566 r"\.corp$", # Corporate domains
567 r"\.lan$", # Local area network
568 r"\.intranet$", # Intranet domains
569 r"\.private$", # Private domains
570 ]
571
572 for pattern in internal_patterns:
573 if re.search(pattern, host_lower):
574 return True
575
576 # If none of the internal patterns match, assume it's external
577 return False
578
579
def add_debug_log_for_notification(
    connection: "MaintNotificationsAbstractConnection",
    notification: Union[str, MaintenanceNotification],
):
    """
    Emit a DEBUG-level log line describing *notification* and the connection
    handling it (including resolved IP and local socket port, when available).
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    local_port = None
    try:
        # Best-effort: the socket may be absent or already closed.
        if connection._sock:
            sockname = connection._sock.getsockname()
            local_port = sockname[1] if sockname else None
    except (AttributeError, OSError):
        pass

    logger.debug(
        f"Handling maintenance notification: {notification}, "
        f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
        f"local socket port: {local_port}",
    )
599
600
class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.

    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
                If None, the endpoint type will be automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        # Validate up front so a misconfigured endpoint_type fails fast,
        # matching the documented ValueError contract.
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            raise ValueError(f"Invalid endpoint type: {endpoint_type}")

        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to request from the server.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
719
720
class MaintNotificationsPoolHandler:
    """
    Pool-level handler for maintenance push notifications.

    Reacts to MOVING notifications by updating the settings of existing pool
    connections (maintenance state, relaxed timeouts, target host), updating
    the pool's connection kwargs for connections created later, optionally
    triggering a proactive reconnect, and scheduling a revert of all the
    temporary changes once the notification's TTL elapses.
    """

    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Notifications already acted upon; used for de-duplication and
        # periodically purged of expired entries.
        self._processed_notifications = set()
        # Reentrant lock: handler methods may re-enter it via helpers
        # (e.g. handle_node_moving_notification -> run_proactive_reconnect).
        self._lock = threading.RLock()
        # The connection this handler instance currently serves; None for the
        # pool-level "template" handler (see get_handler_for_connection).
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        """Bind this handler instance to a specific connection."""
        self.connection = connection

    def get_handler_for_connection(self):
        """
        Create a per-connection handler that shares the processed-notification
        set and lock with this one, but carries its own connection reference.
        """
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        """Drop expired entries from the processed-notifications set."""
        with self._lock:
            # Iterate over a snapshot since we mutate the set while looping.
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """
        Entry point for pool-level notifications. Currently only MOVING
        notifications are handled here; anything else is logged as an error.
        """
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        Apply the temporary MOVING-state changes for *notification*:
        update existing connections and the pool's connection kwargs,
        optionally reconnect proactively, and schedule the revert
        (handle_node_moved_notification) to run after the notification TTL.
        """
        # Nothing to do when both reaction mechanisms are disabled.
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled or is expired
                # just return
                return

            with self.pool._lock:
                logger.debug(
                    f"Handling node MOVING notification: {notification}, "
                    f"with connection: {self.connection}, connected to ip "
                    f"{self.connection.get_resolved_ip() if self.connection else None}"
                )
                # NOTE: always true here given the early return above; kept as
                # an explicit guard around the settings updates.
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any
                    # This is the address that is being moved
                    # and we need to handle only connections
                    # connected to the same address
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            # New endpoint known - reconnect right away.
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # New endpoint unknown - defer the reconnect to
                            # halfway through the notification's TTL.
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                # Schedule the revert of all temporary changes once the
                # notification expires.
                threading.Timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    args=(notification,),
                ).start()

                record_connection_handoff(
                    pool_name=get_pool_name(self.pool),
                )

                self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.

        Reverts the pool's connection kwargs (only if no newer MOVING
        notification has superseded this one) and resets the temporary
        settings on connections tagged with this notification's hash.
        """
        notification_hash = hash(notification)

        with self._lock:
            logger.debug(
                f"Reverting temporary changes related to notification: {notification}, "
                f"with connection: {self.connection}, connected to ip "
                f"{self.connection.get_resolved_ip() if self.connection else None}"
            )
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                # Restore the pre-maintenance values stashed in the pool's
                # connection kwargs (orig_* keys).
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                # Only revert what was actually changed when the notification
                # was applied.
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )
929
930
class MaintNotificationsConnectionHandler:
    """
    Per-connection handler for maintenance notifications that affect a single
    node (slot migration, failover) - as opposed to pool-wide MOVING
    notifications, which are handled by MaintNotificationsPoolHandler.

    On a "starting maintenance" notification it relaxes the connection's
    socket timeouts; on a "completed maintenance" notification it restores
    them and clears the maintenance state.
    """

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        OSSNodeMigratingNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
        OSSNodeMigratedNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """
        Record the notification metric and dispatch to the start/completed
        handler according to the notification's class.
        """
        # get the notification type by checking its class in the _NOTIFICATION_TYPES dict
        notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)

        record_maint_notification_count(
            server_address=self.connection.host,
            server_port=self.connection.port,
            network_peer_address=self.connection.host,
            network_peer_port=self.connection.port,
            maint_notification=notification.__class__.__name__,
        )

        if notification_type is None:
            # Class is not in the dispatch table - log and ignore.
            logger.error(f"Unhandled notification type: {notification}")
            return

        if notification_type:
            # 1 -> maintenance is starting on this connection's node.
            self.handle_maintenance_start_notification(
                MaintenanceState.MAINTENANCE, notification
            )
        else:
            # 0 -> maintenance has completed.
            self.handle_maintenance_completed_notification(notification=notification)

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
    ):
        """
        Relax the connection's socket timeouts for the duration of the
        maintenance. No-op when the connection is already in MOVING state
        (pool-level handling takes precedence) or relaxed timeouts are
        disabled in the config.
        """
        add_debug_log_for_notification(self.connection, notification)

        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(
            tmp_relaxed_timeout=self.config.relaxed_timeout
        )
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
        if isinstance(notification, OSSNodeMigratingNotification):
            # add the notification id to the set of processed start maint notifications
            # this is used to skip the unrelaxing of the timeouts if we have received more
            # than one start notification before the final end notification
            self.connection.add_maint_start_notification(notification.id)

        record_connection_relaxed_timeout(
            connection_name=repr(self.connection),
            maint_notification=notification.__class__.__name__,
            relaxed=True,
        )

    def handle_maintenance_completed_notification(self, **kwargs):
        """
        Restore the connection's original timeouts after maintenance ends.

        Accepts an optional ``notification`` keyword argument; when it is
        absent the call is treated as a generic "maintenance completed"
        signal (logged as MAINTENANCE_COMPLETED, no metric recorded).
        """
        # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        notification = None
        if kwargs.get("notification"):
            notification = kwargs["notification"]
        add_debug_log_for_notification(
            self.connection, notification if notification else "MAINTENANCE_COMPLETED"
        )
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relaxed timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.connection.reset_received_notifications()

        if notification:
            record_connection_relaxed_timeout(
                connection_name=repr(self.connection),
                maint_notification=notification.__class__.__name__,
                relaxed=False,
            )
1030
1031
class OSSMaintNotificationsHandler:
    """
    Cluster-level handler for OSS maintenance push notifications.

    SMIGRATED (``OSSNodeMigratedNotification``) notifications trigger a
    cluster topology refresh and the recycling of connections to the nodes
    affected by the slot migration.  Deduplication state (the processed /
    in-progress notification sets and the lock guarding them) is shared
    between all per-connection copies created via
    ``get_handler_for_connection``, so each notification is handled at most
    once across the whole cluster client.
    """

    def __init__(
        self,
        cluster_client: "MaintNotificationsAbstractRedisCluster",
        config: MaintNotificationsConfig,
    ) -> None:
        self.cluster_client = cluster_client
        self.config = config
        # Fully handled notifications; kept until their TTL expires so that
        # duplicates arriving on other connections are skipped.
        self._processed_notifications = set()
        # Notifications currently being handled; prevents double handling
        # when the reinitialization itself uses a connection that carries
        # the same notification.
        self._in_progress = set()
        # RLock: handling may re-enter on the same thread while the lock is
        # already held.
        self._lock = threading.RLock()

    def get_handler_for_connection(self):
        """
        Return a handler copy for a single connection.

        The copy shares the deduplication sets and the lock with this
        handler (so cross-connection deduplication still works) while each
        connection gets its own handler object, since each connection can
        be in a different state.
        """
        copy = OSSMaintNotificationsHandler(self.cluster_client, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._in_progress = self._in_progress
        copy._lock = self._lock
        return copy

    def remove_expired_notifications(self):
        """Drop processed notifications whose TTL has expired."""
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a notification; only SMIGRATED is handled here."""
        if isinstance(notification, OSSNodeMigratedNotification):
            self.handle_oss_maintenance_completed_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_oss_maintenance_completed_notification(
        self, notification: OSSNodeMigratedNotification
    ):
        """
        Handle an SMIGRATED notification.

        Refreshes the cluster slots/nodes caches (seeding discovery with the
        migration destination nodes), marks in-use connections to migration
        source nodes for reconnect, and disconnects free connections of
        nodes dropped from the topology.  Duplicate notifications (already
        processed or currently in progress) are ignored.
        """
        self.remove_expired_notifications()

        with self._lock:
            if (
                notification in self._in_progress
                or notification in self._processed_notifications
            ):
                # Already processed, or being processed right now: cluster
                # reinitialization runs a CLUSTER SLOTS command that can use
                # a different connection which also carries this very
                # notification, and it must not be handled twice.
                return

            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Handling SMIGRATED notification: {notification}")
            self._in_progress.add(notification)

            try:
                # Extract the src and destination nodes affected by the
                # maintenance.  nodes_to_slots_mapping structure:
                # {
                #     "src_host:port": [
                #         {"dest_host:port": "slot_range"},
                #         ...
                #     ],
                #     ...
                # }
                additional_startup_nodes_info = []
                affected_nodes = set()
                for (
                    src_address,
                    dest_mappings,
                ) in notification.nodes_to_slots_mapping.items():
                    # Split on the LAST ":" so hosts that themselves contain
                    # colons (e.g. IPv6 literals) stay intact.
                    src_host, _, src_port = src_address.rpartition(":")
                    # NOTE(review): src_port is passed as str while dest
                    # ports are converted to int below; get_node presumably
                    # only formats it into "host:port" — confirm an int is
                    # not required.
                    src_node = self.cluster_client.nodes_manager.get_node(
                        host=src_host, port=src_port
                    )
                    if src_node is not None:
                        affected_nodes.add(src_node)

                    for dest_mapping in dest_mappings:
                        for dest_address in dest_mapping.keys():
                            dest_host, _, dest_port = dest_address.rpartition(":")
                            additional_startup_nodes_info.append(
                                (dest_host, int(dest_port))
                            )

                # Update the cluster slots cache with the new slots mapping.
                # This also updates the nodes cache with the new nodes mapping.
                self.cluster_client.nodes_manager.initialize(
                    disconnect_startup_nodes_pools=False,
                    additional_startup_nodes_info=additional_startup_nodes_info,
                )

                # Visit both the migration sources and the refreshed topology:
                # dropped source nodes still need their connections closed.
                all_nodes = affected_nodes.union(
                    self.cluster_client.nodes_manager.nodes_cache.values()
                )

                for current_node in all_nodes:
                    if current_node.redis_connection is None:
                        continue
                    with current_node.redis_connection.connection_pool._lock:
                        if current_node in affected_nodes:
                            # Mark all in-use connections to the node for
                            # reconnect - this forces them to disconnect after
                            # completing their current commands.  Some may be
                            # held by pub/sub and we cannot tell which, so all
                            # in-flight connections are recycled.
                            for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
                                add_debug_log_for_notification(
                                    conn, "SMIGRATED - mark for reconnect"
                                )
                                conn.mark_for_reconnect()
                        else:
                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug(
                                    f"SMIGRATED: Node {current_node.name} not affected by maintenance, "
                                    f"skipping mark for reconnect"
                                )

                        if (
                            current_node
                            not in self.cluster_client.nodes_manager.nodes_cache.values()
                        ):
                            # The node was dropped from the cluster: close all
                            # of its free connections; no need to revert their
                            # timeouts.
                            for conn in current_node.redis_connection.connection_pool._get_free_connections():
                                conn.disconnect()

                # Mark the notification as processed (only on success, so a
                # failed attempt can be retried by a later duplicate).
                self._processed_notifications.add(notification)
            finally:
                # Always clear the in-progress marker.  Without the finally, a
                # failure during reinitialization would leave the notification
                # stuck in _in_progress and every future duplicate would be
                # silently skipped.
                self._in_progress.remove(notification)