import enum
import ipaddress
import logging
import re
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union

from redis.typing import Number


class MaintenanceState(enum.Enum):
    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"


class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Return the string value of the enum."""
        return self.value

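# Note: str(EndpointType.EXTERNAL_IP) == "external-ip"; the enum values are the
# string tokens used with the CLIENT MAINT_NOTIFICATIONS command.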

if TYPE_CHECKING:
    from redis.connection import (
        BlockingConnectionPool,
        ConnectionInterface,
        ConnectionPool,
    )


class MaintenanceEvent(ABC):
    """
    Base class for maintenance events sent through push messages by the Redis server.

    This class provides common functionality for all maintenance events, including
    unique identification and TTL (time-to-live) handling.

    Attributes:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Timestamp when the notification was created/read
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceEvent with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this event
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this event has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the event has expired, False otherwise
        """
        return time.monotonic() > (self.creation_time + self.ttl)

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance event.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the event
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance events for equality.

        This method must be implemented by all concrete subclasses.
        Events are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the events are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance event.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the event
        """
        pass
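
# Expiry works off the monotonic clock: an event constructed with ttl=5 starts
# reporting is_expired() == True roughly five seconds after creation, which is
# how the handlers below prune stale notifications from their state.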


class NodeMovingEvent(MaintenanceEvent):
    """
    Event for when a Redis cluster node is being replaced.

    This event is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingEvent.

        Args:
            id (int): Unique identifier for this event
            new_node_host (Optional[str]): Hostname or IP address of the new replacement node
            new_node_port (Optional[int]): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingEvent events are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingEvent):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type class name, id,
                new_node_host and new_node_port
        """
        try:
            node_port = int(self.new_node_port) if self.new_node_port else None
        except ValueError:
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )
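
# Equality and hashing key on (class name, id, new_node_host, new_node_port),
# so repeated copies of the same MOVING notification map to the same set entry;
# the pool handler below relies on this to process each event only once.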


class NodeMigratingEvent(MaintenanceEvent):
    """
    Event for when a Redis cluster node is in the process of migrating slots.

    This event is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratingEvent events are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratingEvent):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeMigratedEvent(MaintenanceEvent):
    """
    Event for when a Redis cluster node has completed migrating slots.

    This event is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this event
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratedEvent events are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratedEvent):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailingOverEvent(MaintenanceEvent):
    """
    Event for when a Redis cluster node is in the process of failing over.

    This event is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailingOverEvent events are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailingOverEvent):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailedOverEvent(MaintenanceEvent):
    """
    Event for when a Redis cluster node has completed a failover.

    This event is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this event
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailedOverEvent events are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailedOverEvent):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


def _is_private_fqdn(host: str) -> bool:
    """
    Determine if an FQDN is likely to be internal/private.

    This uses heuristics based on RFC 952 and RFC 1123 standards:
    - .local domains (RFC 6762 - Multicast DNS)
    - .internal domains (common internal convention)
    - Single-label hostnames (no dots)
    - Common internal TLDs

    Args:
        host (str): The FQDN to check

    Returns:
        bool: True if the FQDN appears to be internal/private
    """
    host_lower = host.lower().rstrip(".")

    # Single-label hostnames (no dots) are typically internal
    if "." not in host_lower:
        return True

    # Common internal/private domain patterns
    internal_patterns = [
        r"\.local$",  # mDNS/Bonjour domains
        r"\.internal$",  # Common internal convention
        r"\.corp$",  # Corporate domains
        r"\.lan$",  # Local area network
        r"\.intranet$",  # Intranet domains
        r"\.private$",  # Private domains
    ]

    for pattern in internal_patterns:
        if re.search(pattern, host_lower):
            return True

    # If none of the internal patterns match, assume it's external
    return False
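
# For example, "redis-node-1" (single label) and "cache.internal" classify as
# private, while "redis.example.com" falls through to the external branch.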


class MaintenanceEventsConfig:
    """
    Configuration class for maintenance event handling behaviour. Events are
    received through push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.
    """

    def __init__(
        self,
        enabled: bool = True,
        proactive_reconnect: bool = True,
        relax_timeout: Optional[Number] = 20,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintenanceEventsConfig.

        Args:
            enabled (bool): Whether to enable maintenance events handling.
                Defaults to True.
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relax_timeout (Number): The relax timeout to use for the connection during maintenance.
                If -1 is provided, the relax timeout is disabled. Defaults to 20.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in
                CLIENT MAINT_NOTIFICATIONS. If None, the endpoint type is determined automatically
                from the host and TLS configuration. Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relax_timeout = relax_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relax_timeout={self.relax_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relax_timeouts_enabled(self) -> bool:
        """
        Check if the relax timeout is enabled. The value -1 disables it.
        If relax_timeout is None, operations block and wait until a response
        is received.

        Returns:
            True if the relax timeout is enabled, False otherwise.
        """
        return self.relax_timeout != -1
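
    # e.g. relax_timeout=-1 turns relaxed timeouts off entirely, while the
    # default of 20 (or None, meaning "block indefinitely") keeps them enabled.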

    def get_endpoint_type(
        self, host: str, connection: "ConnectionInterface"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for the CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the provided host (the original connection host):
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, use the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User-provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to use for the CLIENT MAINT_NOTIFICATIONS command.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since the IP came from the socket, but fall back anyway
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
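
# Example: a host of "10.0.0.5" maps directly to EndpointType.INTERNAL_IP and
# "8.8.8.8" to EndpointType.EXTERNAL_IP; for "redis.example.com" the
# connection's resolved IP (or, failing that, _is_private_fqdn) decides between
# INTERNAL_FQDN and EXTERNAL_FQDN.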


class MaintenanceEventPoolHandler:
    def __init__(
        self,
        pool: Union["ConnectionPool", "BlockingConnectionPool"],
        config: MaintenanceEventsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        self._processed_events = set()
        self._lock = threading.RLock()
        self.connection = None

    def set_connection(self, connection: "ConnectionInterface"):
        self.connection = connection

    def remove_expired_notifications(self):
        with self._lock:
            for notification in tuple(self._processed_events):
                if notification.is_expired():
                    self._processed_events.remove(notification)

    def handle_event(self, notification: MaintenanceEvent):
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingEvent):
            return self.handle_node_moving_event(notification)
        else:
            logging.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_event(self, event: NodeMovingEvent):
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relax_timeouts_enabled()
        ):
            return
        with self._lock:
            if event in self._processed_events:
                # nothing to do in the connection pool handling
                # the event has already been handled or is expired
                # just return
                return

            with self.pool._lock:
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relax_timeouts_enabled()
                ):
                    # Get the current connected address - if any.
                    # This is the address that is being moved,
                    # and we need to handle only connections
                    # connected to the same address.
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_event_hash=hash(event),
                        relax_timeout=self.config.relax_timeout,
                        host_address=event.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_event_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if event.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            threading.Timer(
                                event.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_event_hash": hash(event),
                    }
                    if event.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": event.new_node_host,
                            }
                        )
                    if self.config.is_relax_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relax_timeout,
                                "socket_connect_timeout": self.config.relax_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                    threading.Timer(
                        event.ttl, self.handle_node_moved_event, args=(event,)
                    ).start()

            self._processed_events.add(event)
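
    # Lifecycle sketch: on a MOVING notification, connections to the departing
    # address get MaintenanceState.MOVING, relaxed timeouts and (when known) the
    # new host; a timer then fires handle_node_moved_event after event.ttl
    # seconds to restore the original settings.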

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care of the active connections in the pool:
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care of the inactive connections in the pool:
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_event(self, event: NodeMovingEvent):
        """
        Handle the cleanup after a node moving event expires.
        """
        event_hash = hash(event)

        with self._lock:
            # if the current maintenance_event_hash in kwargs does not match the event,
            # it means there has been a new moving event after this one
            # and we don't need to revert the kwargs yet
            if self.pool.connection_kwargs.get("maintenance_event_hash") == event_hash:
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_event_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relax_timeout = self.config.is_relax_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relax_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_event_hash=None,
                    matching_event_hash=event_hash,
                    matching_pattern="event_hash",
                    update_event_hash=True,
                    reset_relax_timeout=reset_relax_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )


class MaintenanceEventConnectionHandler:
    # 1 = "starting maintenance" events, 0 = "completed maintenance" events
    _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = {
        NodeMigratingEvent: 1,
        NodeFailingOverEvent: 1,
        NodeMigratedEvent: 0,
        NodeFailedOverEvent: 0,
    }

    def __init__(
        self, connection: "ConnectionInterface", config: MaintenanceEventsConfig
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_event(self, event: MaintenanceEvent):
        # get the event type by checking its class in the _EVENT_TYPES dict
        event_type = self._EVENT_TYPES.get(event.__class__, None)

        if event_type is None:
            logging.error(f"Unhandled event type: {event}")
            return

        if event_type:
            self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
        else:
            self.handle_maintenance_completed_event()

    def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relax_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relax_timeout)

    def handle_maintenance_completed_event(self):
        # Only reset timeouts if state is not MOVING and relax timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relax_timeouts_enabled()
        ):
            return
        self.connection.reset_tmp_settings(reset_relax_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relax timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
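

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; this module is normally driven by
# the connection/pool machinery rather than run directly):
if __name__ == "__main__":
    # Keep relaxed timeouts at 20s and skip endpoint auto-detection by forcing
    # the external-ip endpoint type.
    config = MaintenanceEventsConfig(
        enabled=True,
        proactive_reconnect=True,
        relax_timeout=20,
        endpoint_type=EndpointType.EXTERNAL_IP,
    )
    print(config)
    print("relax timeouts enabled:", config.is_relax_timeouts_enabled())

    # Events with the same id/host/port hash (and compare) equally, so a set
    # deduplicates repeated MOVING notifications.
    first = NodeMovingEvent(id=1, new_node_host="10.0.0.5", new_node_port=6379, ttl=10)
    duplicate = NodeMovingEvent(id=1, new_node_host="10.0.0.5", new_node_port=6379, ttl=10)
    seen = {first}
    print("duplicate already seen:", duplicate in seen)
    print(first)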