Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maint_notifications.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

363 statements  

1import enum 

2import ipaddress 

3import logging 

4import re 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union 

9 

10from redis.observability.attributes import get_pool_name 

11from redis.observability.recorder import ( 

12 record_connection_handoff, 

13 record_connection_relaxed_timeout, 

14 record_maint_notification_count, 

15) 

16from redis.typing import Number 

17 

18if TYPE_CHECKING: 

19 from redis.cluster import MaintNotificationsAbstractRedisCluster 

20 

21logger = logging.getLogger(__name__) 

22 

23 

class MaintenanceState(enum.Enum):
    """
    Maintenance lifecycle state attached to connections / connection kwargs.

    Used by the pool/connection handlers in this module to decide whether
    temporary settings (relaxed timeouts, rewritten host) are in effect.
    """

    NONE = "none"  # no maintenance in progress; default state
    MOVING = "moving"  # node is being replaced (set on MOVING notifications)
    MAINTENANCE = "maintenance"  # migrating/failing-over maintenance window is active

28 

29 

class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"  # private IP address
    INTERNAL_FQDN = "internal-fqdn"  # hostname resolving to a private IP
    EXTERNAL_IP = "external-ip"  # public IP address
    EXTERNAL_FQDN = "external-fqdn"  # hostname resolving to a public IP
    NONE = "none"  # no endpoint information requested

    def __str__(self):
        """Return the string value of the enum (the wire format of the type)."""
        return self.value

42 

43 

44if TYPE_CHECKING: 

45 from redis.connection import ( 

46 MaintNotificationsAbstractConnection, 

47 MaintNotificationsAbstractConnectionPool, 

48 ) 

49 

50 

class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp when the notification was created/read
        expire_at (float): Monotonic timestamp after which the notification is expired
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        # time.monotonic() (not time.time()) so expiry is immune to
        # wall-clock adjustments.
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        # Use the deadline precomputed in __init__ instead of re-deriving
        # creation_time + ttl, so is_expired() and expire_at can never
        # disagree.
        return time.monotonic() > self.expire_at

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass

128 

129 

class NodeMovingNotification(MaintenanceNotification):
    """
    Received when a node is replaced with a new node during cluster
    rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (str): Hostname or IP address of the new replacement node
            new_node_port (int): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        deadline = self.expire_at
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"new_node_host='{self.new_node_host}'",
                f"new_node_port={self.new_node_port}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when id, new_node_host and new_node_port all match."""
        return (
            isinstance(other, NodeMovingNotification)
            and self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Hash on the class name, id, host and a normalized port so instances
        can live in sets and serve as dict keys.
        """
        # Port normalization: falsy port -> None; unparseable port -> 0.
        if self.new_node_port:
            try:
                normalized_port = int(self.new_node_port)
            except ValueError:
                normalized_port = 0
        else:
            normalized_port = None

        return hash(
            (
                type(self).__name__,
                int(self.id),
                str(self.new_node_host),
                normalized_port,
            )
        )

208 

209 

class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when the other object is the same notification type with the same id."""
        return (
            isinstance(other, NodeMigratingNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash based on the class name and id, for use in sets and dicts."""
        return hash((type(self).__name__, int(self.id)))

257 

258 

class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    # Completed-migration notifications carry no TTL of their own.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when the other object is the same notification type with the same id."""
        return (
            isinstance(other, NodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash based on the class name and id, for use in sets and dicts."""
        return hash((type(self).__name__, int(self.id)))

307 

308 

class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of failing over.

    This notification is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when the other object is the same notification type with the same id."""
        return (
            isinstance(other, NodeFailingOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash based on the class name and id, for use in sets and dicts."""
        return hash((type(self).__name__, int(self.id)))

356 

357 

class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed a failover.

    This notification is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    # Completed-failover notifications carry no TTL of their own.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when the other object is the same notification type with the same id."""
        return (
            isinstance(other, NodeFailedOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash based on the class name and id, for use in sets and dicts."""
        return hash((type(self).__name__, int(self.id)))

406 

407 

class OSSNodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        slots (Optional[str]): Raw slots value for the migration as received in the
            push notification; not parsed or validated here. May be None.
    """

    # Fixed TTL: OSS MIGRATING notifications do not carry their own TTL.
    DEFAULT_TTL = 30

    def __init__(
        self,
        id: int,
        slots: Optional[str] = None,
    ):
        super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
        self.slots = slots

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"slots={self.slots}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two OSSNodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type. Note: ``slots`` does not participate in equality.
        """
        if not isinstance(other, OSSNodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

463 

464 

class OSSNodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Map of source node address
            to list of destination mappings. Each destination mapping is a dict with
            the destination node address as key and the slot range as value.

            Structure example:
                {
                    "127.0.0.1:6379": [
                        {"127.0.0.1:6380": "1-100"},
                        {"127.0.0.1:6381": "101-200"}
                    ],
                    "127.0.0.1:6382": [
                        {"127.0.0.1:6383": "201-300"}
                    ]
                }

            Where:
            - Key (str): Source node address in "host:port" format
            - Value (List[Dict[str, str]]): List of destination mappings where each dict
              contains destination node address as key and slot range as value
    """

    # Fixed TTL: OSS MIGRATED notifications do not carry their own TTL.
    DEFAULT_TTL = 120

    def __init__(
        self,
        id: int,
        nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
    ):
        super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
        self.nodes_to_slots_mapping = nodes_to_slots_mapping

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal when the other object is the same notification type with the same id."""
        return (
            isinstance(other, OSSNodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """Hash based on the class name and id, for use in sets and dicts."""
        return hash((type(self).__name__, int(self.id)))

538 

539 

540def _is_private_fqdn(host: str) -> bool: 

541 """ 

542 Determine if an FQDN is likely to be internal/private. 

543 

544 This uses heuristics based on RFC 952 and RFC 1123 standards: 

545 - .local domains (RFC 6762 - Multicast DNS) 

546 - .internal domains (common internal convention) 

547 - Single-label hostnames (no dots) 

548 - Common internal TLDs 

549 

550 Args: 

551 host (str): The FQDN to check 

552 

553 Returns: 

554 bool: True if the FQDN appears to be internal/private 

555 """ 

556 host_lower = host.lower().rstrip(".") 

557 

558 # Single-label hostnames (no dots) are typically internal 

559 if "." not in host_lower: 

560 return True 

561 

562 # Common internal/private domain patterns 

563 internal_patterns = [ 

564 r"\.local$", # mDNS/Bonjour domains 

565 r"\.internal$", # Common internal convention 

566 r"\.corp$", # Corporate domains 

567 r"\.lan$", # Local area network 

568 r"\.intranet$", # Intranet domains 

569 r"\.private$", # Private domains 

570 ] 

571 

572 for pattern in internal_patterns: 

573 if re.search(pattern, host_lower): 

574 return True 

575 

576 # If none of the internal patterns match, assume it's external 

577 return False 

578 

579 

def add_debug_log_for_notification(
    connection: "MaintNotificationsAbstractConnection",
    notification: Union[str, MaintenanceNotification],
):
    """
    Emit a DEBUG-level log line describing *notification* and the connection
    that is handling it. Does nothing when DEBUG logging is disabled.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    # Best effort: look up the local port of the connection's socket; any
    # missing attribute or OS-level failure just leaves it as None.
    local_port = None
    try:
        sock = connection._sock
        if sock:
            sockname = sock.getsockname()
            if sockname:
                local_port = sockname[1]
    except (AttributeError, OSError):
        pass

    logger.debug(
        f"Handling maintenance notification: {notification}, "
        f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
        f"local socket port: {local_port}",
    )

599 

600 

class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.

    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
                If None, the endpoint type will be automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
                NOTE(review): no such validation is visible in this constructor -
                confirm whether it happens elsewhere or should be added.
        """
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to request in CLIENT MAINT_NOTIFICATIONS
                (explicit override, or derived from the host / resolved IP privacy).
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN

719 

720 

class MaintNotificationsPoolHandler:
    """
    Pool-level handler for maintenance push notifications (MOVING).

    Reacts to NodeMovingNotification by relaxing timeouts, rewriting the
    target host for matching connections, optionally reconnecting proactively,
    and scheduling a timer to revert the temporary settings when the
    notification's TTL elapses. Deduplicates notifications via a shared
    processed-notifications set.
    """

    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        # The connection pool whose connections/kwargs are updated.
        self.pool = pool
        # Behaviour configuration (relaxed timeouts, proactive reconnect, ...).
        self.config = config
        # Set of already-handled notifications; shared between per-connection
        # handler copies (see get_handler_for_connection).
        self._processed_notifications = set()
        # Reentrant: handle_node_moving_notification re-enters via nested
        # helper calls that also take this lock.
        self._lock = threading.RLock()
        # The connection on which the notification arrived; set later via
        # set_connection().
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        """Attach the connection this handler instance is serving."""
        self.connection = connection

    def get_handler_for_connection(self):
        """
        Return a per-connection copy of this handler that shares the
        processed-notifications set and lock, but has its own connection slot.
        """
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        """Drop expired entries from the shared processed-notifications set."""
        with self._lock:
            # Iterate over a snapshot so removal is safe during iteration.
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a pool-level notification; only MOVING is handled here."""
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        Handle a MOVING notification: update existing/matching connections and
        the pool's kwargs for new connections, optionally reconnect proactively,
        and schedule the revert (handle_node_moved_notification) after ttl.
        """
        # Nothing to do if both features this handler implements are disabled.
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled or is expired
                # just return
                return

            # Lock ordering: self._lock first, then the pool's lock.
            with self.pool._lock:
                logger.debug(
                    f"Handling node MOVING notification: {notification}, "
                    f"with connection: {self.connection}, connected to ip "
                    f"{self.connection.get_resolved_ip() if self.connection else None}"
                )
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any
                    # This is the address that is being moved
                    # and we need to handle only connections
                    # connected to the same address
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    # getattr feature-detection: truthy (a bound method) only
                    # when the pool implements set_in_maintenance.
                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No replacement host yet: defer the reconnect to
                            # halfway through the notification's TTL window.
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                # Schedule the cleanup/revert for when the notification expires.
                threading.Timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    args=(notification,),
                ).start()

                # Observability: count the handoff for this pool.
                record_connection_handoff(
                    pool_name=get_pool_name(self.pool),
                )

                # Mark as handled so duplicate MOVING pushes are ignored.
                self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.
        """
        notification_hash = hash(notification)

        with self._lock:
            logger.debug(
                f"Reverting temporary changes related to notification: {notification}, "
                f"with connection: {self.connection}, connected to ip "
                f"{self.connection.get_resolved_ip() if self.connection else None}"
            )
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                # Restore the originals that were stashed in connection_kwargs.
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                # Only undo the pieces that were actually applied, per config.
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )

929 

930 

931class MaintNotificationsConnectionHandler: 

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    # Dispatch table used by handle_notification(): maps each concrete
    # notification class to whether it opens (1) or closes (0) a maintenance
    # window. Classes missing from this map are logged as unhandled.
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        OSSNodeMigratingNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
        OSSNodeMigratedNotification: 0,
    }

941 

942 def __init__( 

943 self, 

944 connection: "MaintNotificationsAbstractConnection", 

945 config: MaintNotificationsConfig, 

946 ) -> None: 

947 self.connection = connection 

948 self.config = config 

949 

950 def handle_notification(self, notification: MaintenanceNotification): 

951 # get the notification type by checking its class in the _NOTIFICATION_TYPES dict 

952 notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None) 

953 

954 record_maint_notification_count( 

955 server_address=self.connection.host, 

956 server_port=self.connection.port, 

957 network_peer_address=self.connection.host, 

958 network_peer_port=self.connection.port, 

959 maint_notification=notification.__class__.__name__, 

960 ) 

961 

962 if notification_type is None: 

963 logger.error(f"Unhandled notification type: {notification}") 

964 return 

965 

966 if notification_type: 

967 self.handle_maintenance_start_notification( 

968 MaintenanceState.MAINTENANCE, notification 

969 ) 

970 else: 

971 self.handle_maintenance_completed_notification(notification=notification) 

972 

973 def handle_maintenance_start_notification( 

974 self, maintenance_state: MaintenanceState, notification: MaintenanceNotification 

975 ): 

976 add_debug_log_for_notification(self.connection, notification) 

977 

978 if ( 

979 self.connection.maintenance_state == MaintenanceState.MOVING 

980 or not self.config.is_relaxed_timeouts_enabled() 

981 ): 

982 return 

983 

984 self.connection.maintenance_state = maintenance_state 

985 self.connection.set_tmp_settings( 

986 tmp_relaxed_timeout=self.config.relaxed_timeout 

987 ) 

988 # extend the timeout for all created connections 

989 self.connection.update_current_socket_timeout(self.config.relaxed_timeout) 

990 if isinstance(notification, OSSNodeMigratingNotification): 

991 # add the notification id to the set of processed start maint notifications 

992 # this is used to skip the unrelaxing of the timeouts if we have received more than 

993 # one start notification before the the final end notification 

994 self.connection.add_maint_start_notification(notification.id) 

995 

996 record_connection_relaxed_timeout( 

997 connection_name=repr(self.connection), 

998 maint_notification=notification.__class__.__name__, 

999 relaxed=True, 

1000 ) 

1001 

1002 def handle_maintenance_completed_notification(self, **kwargs): 

1003 # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled 

1004 if ( 

1005 self.connection.maintenance_state == MaintenanceState.MOVING 

1006 or not self.config.is_relaxed_timeouts_enabled() 

1007 ): 

1008 return 

1009 notification = None 

1010 if kwargs.get("notification"): 

1011 notification = kwargs["notification"] 

1012 add_debug_log_for_notification( 

1013 self.connection, notification if notification else "MAINTENANCE_COMPLETED" 

1014 ) 

1015 self.connection.reset_tmp_settings(reset_relaxed_timeout=True) 

1016 # Maintenance completed - reset the connection 

1017 # timeouts by providing -1 as the relaxed timeout 

1018 self.connection.update_current_socket_timeout(-1) 

1019 self.connection.maintenance_state = MaintenanceState.NONE 

1020 # reset the sets that keep track of received start maint 

1021 # notifications and skipped end maint notifications 

1022 self.connection.reset_received_notifications() 

1023 

1024 if notification: 

1025 record_connection_relaxed_timeout( 

1026 connection_name=repr(self.connection), 

1027 maint_notification=notification.__class__.__name__, 

1028 relaxed=False, 

1029 ) 

1030 

1031 

1032class OSSMaintNotificationsHandler: 

1033 def __init__( 

1034 self, 

1035 cluster_client: "MaintNotificationsAbstractRedisCluster", 

1036 config: MaintNotificationsConfig, 

1037 ) -> None: 

1038 self.cluster_client = cluster_client 

1039 self.config = config 

1040 self._processed_notifications = set() 

1041 self._in_progress = set() 

1042 self._lock = threading.RLock() 

1043 

1044 def get_handler_for_connection(self): 

1045 # Copy all data that should be shared between connections 

1046 # but each connection should have its own pool handler 

1047 # since each connection can be in a different state 

1048 copy = OSSMaintNotificationsHandler(self.cluster_client, self.config) 

1049 copy._processed_notifications = self._processed_notifications 

1050 copy._in_progress = self._in_progress 

1051 copy._lock = self._lock 

1052 return copy 

1053 

1054 def remove_expired_notifications(self): 

1055 with self._lock: 

1056 for notification in tuple(self._processed_notifications): 

1057 if notification.is_expired(): 

1058 self._processed_notifications.remove(notification) 

1059 

1060 def handle_notification(self, notification: MaintenanceNotification): 

1061 if isinstance(notification, OSSNodeMigratedNotification): 

1062 self.handle_oss_maintenance_completed_notification(notification) 

1063 else: 

1064 logger.error(f"Unhandled notification type: {notification}") 

1065 

    def handle_oss_maintenance_completed_notification(
        self, notification: OSSNodeMigratedNotification
    ):
        """Process an SMIGRATED (slot migration completed) notification.

        Re-discovers the cluster topology, marks in-use connections to the
        affected source nodes for reconnect, and disconnects free connections
        to nodes that were dropped from the cluster. The whole operation runs
        under ``self._lock`` so concurrent deliveries of the same notification
        are deduplicated.
        """
        # Purge stale entries first so the dedup check below stays bounded.
        self.remove_expired_notifications()

        with self._lock:
            if (
                notification in self._in_progress
                or notification in self._processed_notifications
            ):
                # we are already handling this notification or it has already been processed
                # we should skip in_progress notification since when we reinitialize the cluster
                # we execute a CLUSTER SLOTS command that can use a different connection
                # that has also has the notification and we don't want to
                # process the same notification twice
                return

            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Handling SMIGRATED notification: {notification}")
            # NOTE(review): nothing removes this entry if an exception is
            # raised below (e.g. during initialize()); the notification would
            # then be skipped forever. Consider a try/finally — confirm the
            # intended behavior on failure.
            self._in_progress.add(notification)

            # Extract the information about the src and destination nodes that are affected
            # by the maintenance. nodes_to_slots_mapping structure:
            # {
            #     "src_host:port": [
            #         {"dest_host:port": "slot_range"},
            #         ...
            #     ],
            #     ...
            # }
            additional_startup_nodes_info = []
            affected_nodes = set()
            for (
                src_address,
                dest_mappings,
            ) in notification.nodes_to_slots_mapping.items():
                src_host, src_port = src_address.split(":")
                # NOTE(review): src_port is passed as a str here while
                # dest_port below is converted with int() — verify that
                # get_node() accepts both representations.
                src_node = self.cluster_client.nodes_manager.get_node(
                    host=src_host, port=src_port
                )
                if src_node is not None:
                    affected_nodes.add(src_node)

                for dest_mapping in dest_mappings:
                    for dest_address in dest_mapping.keys():
                        dest_host, dest_port = dest_address.split(":")
                        additional_startup_nodes_info.append(
                            (dest_host, int(dest_port))
                        )

            # Updates the cluster slots cache with the new slots mapping
            # This will also update the nodes cache with the new nodes mapping
            self.cluster_client.nodes_manager.initialize(
                disconnect_startup_nodes_pools=False,
                additional_startup_nodes_info=additional_startup_nodes_info,
            )

            # Visit both the pre-refresh affected nodes and the freshly
            # discovered ones: nodes dropped from the cluster are only
            # reachable through the former.
            all_nodes = set(affected_nodes)
            all_nodes = all_nodes.union(
                self.cluster_client.nodes_manager.nodes_cache.values()
            )

            for current_node in all_nodes:
                if current_node.redis_connection is None:
                    continue
                with current_node.redis_connection.connection_pool._lock:
                    if current_node in affected_nodes:
                        # mark for reconnect all in use connections to the node - this will force them to
                        # disconnect after they complete their current commands
                        # Some of them might be used by pub sub and we don't know which ones - so we disconnect
                        # all in flight connections after they are done with current command execution
                        for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
                            add_debug_log_for_notification(
                                conn, "SMIGRATED - mark for reconnect"
                            )
                            conn.mark_for_reconnect()
                    else:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"SMIGRATED: Node {current_node.name} not affected by maintenance, "
                                f"skipping mark for reconnect"
                            )

                    if (
                        current_node
                        not in self.cluster_client.nodes_manager.nodes_cache.values()
                    ):
                        # disconnect all free connections to the node - this node will be dropped
                        # from the cluster, so we don't need to revert the timeouts
                        for conn in current_node.redis_connection.connection_pool._get_free_connections():
                            conn.disconnect()

            # mark the notification as processed
            self._processed_notifications.add(notification)
            self._in_progress.remove(notification)