Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maint_notifications.py: 27%

377 statements  

import enum
import ipaddress
import logging
import re
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

from redis.observability.attributes import get_pool_name
from redis.observability.recorder import (
    record_connection_handoff,
    record_connection_relaxed_timeout,
    record_maint_notification_count,
)
from redis.typing import Number

if TYPE_CHECKING:
    from redis.cluster import MaintNotificationsAbstractRedisCluster

logger = logging.getLogger(__name__)


class MaintenanceState(enum.Enum):
    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"


class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Return the string value of the enum."""
        return self.value


if TYPE_CHECKING:
    from redis.connection import (
        MaintNotificationsAbstractConnection,
        MaintNotificationsAbstractConnectionPool,
    )


class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by the Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Timestamp when the notification was created/read
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        return time.monotonic() > (self.creation_time + self.ttl)

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass

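
# Illustrative sketch (not part of redis-py): a hypothetical minimal subclass that
# shows how the TTL bookkeeping above behaves. `creation_time` is captured at
# construction, and `is_expired()` compares elapsed monotonic time against `ttl`.
def _example_notification_ttl():
    class _DummyNotification(MaintenanceNotification):
        def __repr__(self) -> str:
            return f"_DummyNotification(id={self.id}, ttl={self.ttl})"

        def __eq__(self, other) -> bool:
            return isinstance(other, _DummyNotification) and self.id == other.id

        def __hash__(self) -> int:
            return hash((self.__class__.__name__, self.id))

    notification = _DummyNotification(id=1, ttl=2)
    assert not notification.is_expired()  # freshly created, well within its TTL
    time.sleep(2.1)
    assert notification.is_expired()  # the 2-second TTL has elapsed
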

class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (str): Hostname or IP address of the new replacement node
            new_node_port (int): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type class name, id,
                new_node_host and new_node_port
        """
        try:
            node_port = int(self.new_node_port) if self.new_node_port else None
        except ValueError:
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )

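
# Illustrative sketch (not part of redis-py): because __eq__/__hash__ are defined
# on id, host and port, duplicate MOVING notifications collapse when stored in a
# set, which is how the pool handler further below deduplicates already-processed
# events. The addresses used here are made-up example values.
def _example_moving_notification_dedup():
    first = NodeMovingNotification(
        id=7, new_node_host="10.0.0.5", new_node_port=6379, ttl=15
    )
    duplicate = NodeMovingNotification(
        id=7, new_node_host="10.0.0.5", new_node_port=6379, ttl=15
    )
    processed = {first}
    assert duplicate in processed  # same id/host/port -> treated as the same event
    processed.add(duplicate)
    assert len(processed) == 1
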

class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratedNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratedNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of failing over.

    This notification is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailingOverNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailingOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed a failover.

    This notification is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailedOverNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailedOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class OSSNodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        slots (Optional[str]): Slot ranges being migrated
    """

    DEFAULT_TTL = 30

    def __init__(
        self,
        id: int,
        slots: Optional[str] = None,
    ):
        super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
        self.slots = slots

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"slots={self.slots}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two OSSNodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, OSSNodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class OSSNodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Map of source node address
            to list of destination mappings. Each destination mapping is a dict with
            the destination node address as key and the slot range as value.

            Structure example:
            {
                "127.0.0.1:6379": [
                    {"127.0.0.1:6380": "1-100"},
                    {"127.0.0.1:6381": "101-200"}
                ],
                "127.0.0.1:6382": [
                    {"127.0.0.1:6383": "201-300"}
                ]
            }

            Where:
            - Key (str): Source node address in "host:port" format
            - Value (List[Dict[str, str]]): List of destination mappings where each dict
              contains destination node address as key and slot range as value
    """

    DEFAULT_TTL = 120

    def __init__(
        self,
        id: int,
        nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
    ):
        super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
        self.nodes_to_slots_mapping = nodes_to_slots_mapping

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two OSSNodeMigratedNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, OSSNodeMigratedNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

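
# Illustrative sketch (not part of redis-py): constructing a SMIGRATED notification
# with the nodes_to_slots_mapping structure documented in the class docstring above.
# The addresses and slot ranges are made up purely for illustration.
def _example_oss_migrated_notification():
    notification = OSSNodeMigratedNotification(
        id=42,
        nodes_to_slots_mapping={
            "127.0.0.1:6379": [
                {"127.0.0.1:6380": "1-100"},
                {"127.0.0.1:6381": "101-200"},
            ],
        },
    )
    # The TTL is fixed at DEFAULT_TTL (120s); after that the handler may drop the
    # notification from its processed-notifications set.
    assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL
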

def _is_private_fqdn(host: str) -> bool:
    """
    Determine if an FQDN is likely to be internal/private.

    This uses heuristics based on RFC 952 and RFC 1123 standards:
    - .local domains (RFC 6762 - Multicast DNS)
    - .internal domains (common internal convention)
    - Single-label hostnames (no dots)
    - Common internal TLDs

    Args:
        host (str): The FQDN to check

    Returns:
        bool: True if the FQDN appears to be internal/private
    """
    host_lower = host.lower().rstrip(".")

    # Single-label hostnames (no dots) are typically internal
    if "." not in host_lower:
        return True

    # Common internal/private domain patterns
    internal_patterns = [
        r"\.local$",  # mDNS/Bonjour domains
        r"\.internal$",  # Common internal convention
        r"\.corp$",  # Corporate domains
        r"\.lan$",  # Local area network
        r"\.intranet$",  # Intranet domains
        r"\.private$",  # Private domains
    ]

    for pattern in internal_patterns:
        if re.search(pattern, host_lower):
            return True

    # If none of the internal patterns match, assume it's external
    return False

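
# Illustrative sketch (not part of redis-py): how the heuristic above classifies a
# few hostnames. The names are hypothetical examples, not real deployments.
def _example_private_fqdn_heuristic():
    assert _is_private_fqdn("redis-node-1") is True  # single-label hostname
    assert _is_private_fqdn("cache.internal") is True  # matches the .internal pattern
    assert _is_private_fqdn("redis.example.com") is False  # no internal pattern -> external
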

notification_types_mapping: dict[type[MaintenanceNotification], str] = {
    NodeMovingNotification: "MOVING",
    NodeMigratingNotification: "MIGRATING",
    NodeMigratedNotification: "MIGRATED",
    NodeFailingOverNotification: "FAILING_OVER",
    NodeFailedOverNotification: "FAILED_OVER",
    OSSNodeMigratingNotification: "SMIGRATING",
    OSSNodeMigratedNotification: "SMIGRATED",
}


def add_debug_log_for_notification(
    connection: "MaintNotificationsAbstractConnection",
    notification: Union[str, MaintenanceNotification],
):
    if logger.isEnabledFor(logging.DEBUG):
        socket_address = None
        try:
            socket_address = (
                connection._sock.getsockname() if connection._sock else None
            )
            socket_address = socket_address[1] if socket_address else None
        except (AttributeError, OSError):
            pass

        logger.debug(
            f"Handling maintenance notification: {notification}, "
            f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
            f"local socket port: {socket_address}",
        )


class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.

    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
                If None, the endpoint type will be automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, operations block and wait
        until a response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to use in the CLIENT MAINT_NOTIFICATIONS command.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN

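
# Illustrative sketch (not part of redis-py usage docs): constructing a config that
# requires the CLIENT MAINT_NOTIFICATIONS handshake to succeed, keeps proactive
# reconnects, relaxes timeouts to 30s during maintenance, and pins the endpoint type
# instead of letting get_endpoint_type() infer it. The values are arbitrary examples.
def _example_maint_notifications_config():
    config = MaintNotificationsConfig(
        enabled=True,
        proactive_reconnect=True,
        relaxed_timeout=30,
        endpoint_type=EndpointType.EXTERNAL_FQDN,
    )
    assert config.is_relaxed_timeouts_enabled()  # anything other than -1 keeps it enabled

    disabled = MaintNotificationsConfig(relaxed_timeout=-1)
    assert not disabled.is_relaxed_timeouts_enabled()
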

class MaintNotificationsPoolHandler:
    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        self._processed_notifications = set()
        self._lock = threading.RLock()
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        self.connection = connection

    def get_handler_for_connection(self):
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled or is expired
                # just return
                return

            with self.pool._lock:
                logger.debug(
                    f"Handling node MOVING notification: {notification}, "
                    f"with connection: {self.connection}, connected to ip "
                    f"{self.connection.get_resolved_ip() if self.connection else None}"
                )
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any
                    # This is the address that is being moved
                    # and we need to handle only connections
                    # connected to the same address
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                threading.Timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    args=(notification,),
                ).start()

                record_connection_handoff(
                    pool_name=get_pool_name(self.pool),
                )

            self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.
        """
        notification_hash = hash(notification)

        with self._lock:
            logger.debug(
                f"Reverting temporary changes related to notification: {notification}, "
                f"with connection: {self.connection}, connected to ip "
                f"{self.connection.get_resolved_ip() if self.connection else None}"
            )
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )


class MaintNotificationsConnectionHandler:
    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        OSSNodeMigratingNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
        OSSNodeMigratedNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        self.connection = connection
        self.config = config

    def _get_pool_name(self) -> str:
        """
        Get the pool name from the connection's pool handler.
        Falls back to connection representation if pool is not available.
        """
        pool_handler = getattr(
            self.connection, "_maint_notifications_pool_handler", None
        )
        if pool_handler and getattr(pool_handler, "pool", None):
            return get_pool_name(pool_handler.pool)
        # Fallback for standalone connections without a pool
        return repr(self.connection)

    def handle_notification(self, notification: MaintenanceNotification):
        # get the notification type by checking its class in the _NOTIFICATION_TYPES dict
        notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)
        maint_notification = notification_types_mapping.get(notification.__class__, "")

        record_maint_notification_count(
            server_address=self.connection.host,
            server_port=self.connection.port,
            network_peer_address=self.connection.host,
            network_peer_port=self.connection.port,
            maint_notification=maint_notification,
        )

        if notification_type is None:
            logger.error(f"Unhandled notification type: {notification}")
            return

        if notification_type:
            self.handle_maintenance_start_notification(
                MaintenanceState.MAINTENANCE, notification
            )
        else:
            self.handle_maintenance_completed_notification(notification=notification)

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
    ):
        add_debug_log_for_notification(self.connection, notification)

        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(
            tmp_relaxed_timeout=self.config.relaxed_timeout
        )
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
        if isinstance(notification, OSSNodeMigratingNotification):
            # add the notification id to the set of processed start maint notifications
            # this is used to skip the unrelaxing of the timeouts if we have received more than
            # one start notification before the final end notification
            self.connection.add_maint_start_notification(notification.id)

        maint_notification = notification_types_mapping.get(notification.__class__, "")
        record_connection_relaxed_timeout(
            connection_name=self._get_pool_name(),
            maint_notification=maint_notification,
            relaxed=True,
        )

    def handle_maintenance_completed_notification(self, **kwargs):
        # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        notification = None
        if kwargs.get("notification"):
            notification = kwargs["notification"]
        add_debug_log_for_notification(
            self.connection, notification if notification else "MAINTENANCE_COMPLETED"
        )
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relaxed timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.connection.reset_received_notifications()

        if notification:
            maint_notification = notification_types_mapping.get(
                notification.__class__, ""
            )
            record_connection_relaxed_timeout(
                connection_name=self._get_pool_name(),
                maint_notification=maint_notification,
                relaxed=False,
            )

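
# Illustrative sketch (not part of redis-py): the _NOTIFICATION_TYPES table above
# classifies push messages into "start" (1) and "completed" (0) maintenance events;
# start events relax the connection timeouts, completed events restore them.
def _example_notification_classification():
    start = MaintNotificationsConnectionHandler._NOTIFICATION_TYPES[
        NodeMigratingNotification
    ]
    done = MaintNotificationsConnectionHandler._NOTIFICATION_TYPES[
        NodeMigratedNotification
    ]
    assert start == 1 and done == 0
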

class OSSMaintNotificationsHandler:
    def __init__(
        self,
        cluster_client: "MaintNotificationsAbstractRedisCluster",
        config: MaintNotificationsConfig,
    ) -> None:
        self.cluster_client = cluster_client
        self.config = config
        self._processed_notifications = set()
        self._in_progress = set()
        self._lock = threading.RLock()

    def get_handler_for_connection(self):
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = OSSMaintNotificationsHandler(self.cluster_client, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._in_progress = self._in_progress
        copy._lock = self._lock
        return copy

    def remove_expired_notifications(self):
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        if isinstance(notification, OSSNodeMigratedNotification):
            self.handle_oss_maintenance_completed_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_oss_maintenance_completed_notification(
        self, notification: OSSNodeMigratedNotification
    ):
        self.remove_expired_notifications()

        with self._lock:
            if (
                notification in self._in_progress
                or notification in self._processed_notifications
            ):
                # we are already handling this notification or it has already been processed
                # we should skip in_progress notification since when we reinitialize the cluster
                # we execute a CLUSTER SLOTS command that can use a different connection
                # that also has the notification and we don't want to
                # process the same notification twice
                return

            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Handling SMIGRATED notification: {notification}")
            self._in_progress.add(notification)

            # Extract the information about the src and destination nodes that are affected
            # by the maintenance. nodes_to_slots_mapping structure:
            # {
            #     "src_host:port": [
            #         {"dest_host:port": "slot_range"},
            #         ...
            #     ],
            #     ...
            # }
            additional_startup_nodes_info = []
            affected_nodes = set()
            for (
                src_address,
                dest_mappings,
            ) in notification.nodes_to_slots_mapping.items():
                src_host, src_port = src_address.split(":")
                src_node = self.cluster_client.nodes_manager.get_node(
                    host=src_host, port=src_port
                )
                if src_node is not None:
                    affected_nodes.add(src_node)

                for dest_mapping in dest_mappings:
                    for dest_address in dest_mapping.keys():
                        dest_host, dest_port = dest_address.split(":")
                        additional_startup_nodes_info.append(
                            (dest_host, int(dest_port))
                        )

            # Updates the cluster slots cache with the new slots mapping
            # This will also update the nodes cache with the new nodes mapping
            self.cluster_client.nodes_manager.initialize(
                disconnect_startup_nodes_pools=False,
                additional_startup_nodes_info=additional_startup_nodes_info,
            )

            all_nodes = set(affected_nodes)
            all_nodes = all_nodes.union(
                self.cluster_client.nodes_manager.nodes_cache.values()
            )

            for current_node in all_nodes:
                if current_node.redis_connection is None:
                    continue
                with current_node.redis_connection.connection_pool._lock:
                    handoff_recorded = False
                    if current_node in affected_nodes:
                        # mark for reconnect all in use connections to the node - this will force them to
                        # disconnect after they complete their current commands
                        # Some of them might be used by pub/sub and we don't know which ones - so we disconnect
                        # all in flight connections after they are done with current command execution
                        for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
                            add_debug_log_for_notification(
                                conn, "SMIGRATED - mark for reconnect"
                            )
                            conn.mark_for_reconnect()

                        record_connection_handoff(
                            pool_name=get_pool_name(
                                current_node.redis_connection.connection_pool
                            )
                        )
                        handoff_recorded = True
                    else:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"SMIGRATED: Node {current_node.name} not affected by maintenance, "
                                f"skipping mark for reconnect"
                            )

                    if (
                        current_node
                        not in self.cluster_client.nodes_manager.nodes_cache.values()
                    ):
                        # disconnect all free connections to the node - this node will be dropped
                        # from the cluster, so we don't need to revert the timeouts
                        for conn in current_node.redis_connection.connection_pool._get_free_connections():
                            conn.disconnect()

                        # Only record handoff if not already recorded for this node
                        if not handoff_recorded:
                            record_connection_handoff(
                                pool_name=get_pool_name(
                                    current_node.redis_connection.connection_pool
                                )
                            )

            # mark the notification as processed
            self._processed_notifications.add(notification)
            self._in_progress.remove(notification)