Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maint_notifications.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

349 statements  

1import enum 

2import ipaddress 

3import logging 

4import re 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union 

9 

10from redis.typing import Number 

11 

12if TYPE_CHECKING: 

13 from redis.cluster import MaintNotificationsAbstractRedisCluster 

14 

15logger = logging.getLogger(__name__) 

16 

17 

class MaintenanceState(enum.Enum):
    """Maintenance lifecycle state tracked per connection.

    NONE - no maintenance in progress; MOVING - the node endpoint is being
    replaced; MAINTENANCE - the node is migrating slots / failing over.
    """

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"

22 

23 

class EndpointType(enum.Enum):
    """Endpoint types accepted by the CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self) -> str:
        # Render as the raw protocol token (e.g. "internal-ip") so the enum
        # can be interpolated directly into the command arguments.
        return self.value

36 

37 

38if TYPE_CHECKING: 

39 from redis.connection import ( 

40 MaintNotificationsAbstractConnection, 

41 MaintNotificationsAbstractConnectionPool, 

42 ) 

43 

44 

class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp when the notification was created/read
        expire_at (float): Monotonic timestamp after which the notification is expired
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        # time.monotonic() is immune to system clock adjustments, making it
        # the right base for TTL arithmetic.
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        # Use the expiry precomputed in __init__ instead of recomputing
        # creation_time + ttl, so expiry stays consistent with expire_at.
        return time.monotonic() > self.expire_at

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass

122 

123 

class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (Optional[str]): Hostname or IP address of the new
                replacement node, or None when the push message omitted it
            new_node_port (Optional[int]): Port number of the new replacement
                node, or None when the push message omitted it
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type class name, id,
                 new_node_host and new_node_port
        """
        # Normalize the port defensively: falsy ports hash as None, and any
        # value that cannot be converted to int hashes as 0. TypeError is
        # caught alongside ValueError so a malformed (non-numeric, non-string)
        # port cannot raise out of __hash__.
        try:
            node_port = int(self.new_node_port) if self.new_node_port else None
        except (TypeError, ValueError):
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )

202 

203 

class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expires_at = self.creation_time + self.ttl
        time_left = max(0, expires_at - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={expires_at}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when the other object is the exact same notification type
        and carries the same id.
        """
        return (
            isinstance(other, NodeMigratingNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name plus id so instances can be used
        in sets and as dictionary keys.
        """
        return hash((type(self).__name__, int(self.id)))

251 

252 

class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    # Short fixed TTL: a "completed" event only needs to linger briefly.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expires_at = self.creation_time + self.ttl
        time_left = max(0, expires_at - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={expires_at}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when the other object is the exact same notification type
        and carries the same id.
        """
        return (
            isinstance(other, NodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name plus id so instances can be used
        in sets and as dictionary keys.
        """
        return hash((type(self).__name__, int(self.id)))

301 

302 

class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of failing over.

    This notification is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expires_at = self.creation_time + self.ttl
        time_left = max(0, expires_at - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={expires_at}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when the other object is the exact same notification type
        and carries the same id.
        """
        return (
            isinstance(other, NodeFailingOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name plus id so instances can be used
        in sets and as dictionary keys.
        """
        return hash((type(self).__name__, int(self.id)))

350 

351 

class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed a failover.

    This notification is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    # Short fixed TTL: a "completed" event only needs to linger briefly.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expires_at = self.creation_time + self.ttl
        time_left = max(0, expires_at - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={expires_at}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when the other object is the exact same notification type
        and carries the same id.
        """
        return (
            isinstance(other, NodeFailedOverNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name plus id so instances can be used
        in sets and as dictionary keys.
        """
        return hash((type(self).__name__, int(self.id)))

400 

401 

class OSSNodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        slots (Optional[str]): Slots being migrated, kept as the raw string
            payload received in the push message (not parsed into integers)
    """

    DEFAULT_TTL = 30

    def __init__(
        self,
        id: int,
        slots: Optional[str] = None,
    ):
        super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
        self.slots = slots

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"slots={self.slots}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two OSSNodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, OSSNodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

457 

458 

class OSSNodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis OSS API client is used and a node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Map of source
            node address ("host:port") to a list of destination mappings; each
            destination mapping is a one-entry dict of destination node address
            ("host:port") to the slot range it received.

            Structure example:
            {
                "127.0.0.1:6379": [
                    {"127.0.0.1:6380": "1-100"},
                    {"127.0.0.1:6381": "101-200"}
                ],
                "127.0.0.1:6382": [
                    {"127.0.0.1:6383": "201-300"}
                ]
            }
    """

    # Longer TTL than other "completed" events: callers may need the
    # slot-remapping details for a while after migration finishes.
    DEFAULT_TTL = 120

    def __init__(
        self,
        id: int,
        nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
    ):
        super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
        self.nodes_to_slots_mapping = nodes_to_slots_mapping

    def __repr__(self) -> str:
        expires_at = self.creation_time + self.ttl
        time_left = max(0, expires_at - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={expires_at}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when the other object is the exact same notification type
        and carries the same id.
        """
        return (
            isinstance(other, OSSNodeMigratedNotification)
            and type(self) is type(other)
            and self.id == other.id
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name plus id so instances can be used
        in sets and as dictionary keys.
        """
        return hash((type(self).__name__, int(self.id)))

532 

533 

534def _is_private_fqdn(host: str) -> bool: 

535 """ 

536 Determine if an FQDN is likely to be internal/private. 

537 

538 This uses heuristics based on RFC 952 and RFC 1123 standards: 

539 - .local domains (RFC 6762 - Multicast DNS) 

540 - .internal domains (common internal convention) 

541 - Single-label hostnames (no dots) 

542 - Common internal TLDs 

543 

544 Args: 

545 host (str): The FQDN to check 

546 

547 Returns: 

548 bool: True if the FQDN appears to be internal/private 

549 """ 

550 host_lower = host.lower().rstrip(".") 

551 

552 # Single-label hostnames (no dots) are typically internal 

553 if "." not in host_lower: 

554 return True 

555 

556 # Common internal/private domain patterns 

557 internal_patterns = [ 

558 r"\.local$", # mDNS/Bonjour domains 

559 r"\.internal$", # Common internal convention 

560 r"\.corp$", # Corporate domains 

561 r"\.lan$", # Local area network 

562 r"\.intranet$", # Intranet domains 

563 r"\.private$", # Private domains 

564 ] 

565 

566 for pattern in internal_patterns: 

567 if re.search(pattern, host_lower): 

568 return True 

569 

570 # If none of the internal patterns match, assume it's external 

571 return False 

572 

573 

def add_debug_log_for_notification(
    connection: "MaintNotificationsAbstractConnection",
    notification: Union[str, MaintenanceNotification],
):
    """Log (at DEBUG level) which notification is being handled, on which
    connection, including the resolved peer IP and the local socket port.

    Best-effort: socket introspection failures are swallowed so logging can
    never break notification handling.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    local_port = None
    try:
        sock = connection._sock
        sock_name = sock.getsockname() if sock else None
        local_port = sock_name[1] if sock_name else None
    except (AttributeError, OSError):
        # Socket may be missing or already closed - keep the port as None.
        local_port = None

    logger.debug(
        f"Handling maintenance notification: {notification}, "
        f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
        f"local socket port: {local_port}",
    )

593 

594 

class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.

    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
                If None, the endpoint type will be automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid EndpointType.
        """
        # Validate eagerly so a misconfigured endpoint_type fails here (as the
        # docstring promises) instead of surfacing later during the handshake.
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            raise ValueError(f"Invalid endpoint_type: {endpoint_type!r}")
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to request from the server.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN

713 

714 

class MaintNotificationsPoolHandler:
    """
    Pool-level handler for maintenance push notifications.

    Reacts to node MOVING notifications by updating the connection pool's
    settings (state, timeouts, target host), optionally scheduling a proactive
    reconnect, and reverting the temporary changes once the notification's TTL
    elapses.

    Thread-safety: all mutating paths take ``self._lock`` (an RLock shared
    between per-connection copies) and, where pool internals are touched,
    ``self.pool._lock`` nested inside it.
    """

    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        # Pool whose connections and kwargs this handler mutates.
        self.pool = pool
        self.config = config
        # Notifications already acted upon; shared by reference with the
        # copies produced by get_handler_for_connection().
        self._processed_notifications = set()
        self._lock = threading.RLock()
        # Connection whose parser delivered the notification (set later via
        # set_connection); used for logging and peer-address matching.
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        """Attach the connection this handler instance serves."""
        self.connection = connection

    def get_handler_for_connection(self):
        """Return a per-connection copy sharing the processed-set and lock."""
        # Copy all data that should be shared between connections
        # but each connection should have its own pool handler
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        """Drop expired entries from the shared processed-notifications set."""
        with self._lock:
            # Iterate over a snapshot since we mutate the set while looping.
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a pool-level notification (currently only MOVING)."""
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logger.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        React to a node MOVING notification: retarget/relax existing
        connections, update kwargs for future connections, optionally
        reconnect proactively, and schedule the revert after the TTL.
        """
        # Nothing to do when both reactions are disabled by configuration.
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled or is expired
                # just return
                return

            with self.pool._lock:
                logger.debug(
                    f"Handling node MOVING notification: {notification}, "
                    f"with connection: {self.connection}, connected to ip "
                    f"{self.connection.get_resolved_ip() if self.connection else None}"
                )
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any
                    # This is the address that is being moved
                    # and we need to handle only connections
                    # connected to the same address
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No destination host provided: defer the reconnect
                            # to half of the TTL so the server has time to
                            # complete the move before we re-resolve.
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

                # Schedule the revert of all temporary changes once the
                # notification's TTL expires.
                threading.Timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    args=(notification,),
                ).start()

            # Remember the notification so duplicates are ignored until expiry.
            self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.
        """
        notification_hash = hash(notification)

        with self._lock:
            logger.debug(
                f"Reverting temporary changes related to notification: {notification}, "
                f"with connection: {self.connection}, connected to ip "
                f"{self.connection.get_resolved_ip() if self.connection else None}"
            )
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                # Restore the pre-maintenance defaults for future connections.
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                # Only undo what this configuration could have changed.
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                # Revert settings on connections tagged with this notification's
                # hash (relaxed_timeout=-1 means "disable the relaxed timeout").
                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )

919 

920 

class MaintNotificationsConnectionHandler:
    """Applies maintenance notifications to a single connection.

    Each known notification class maps to a flag:
    1 = "starting maintenance" notifications, 0 = "completed maintenance"
    notifications.
    """

    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        OSSNodeMigratingNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
        OSSNodeMigratedNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a notification to the start/completed handler by its class."""
        kind = self._NOTIFICATION_TYPES.get(type(notification))
        if kind is None:
            logger.error(f"Unhandled notification type: {notification}")
            return

        if kind:
            self.handle_maintenance_start_notification(
                MaintenanceState.MAINTENANCE, notification
            )
        else:
            self.handle_maintenance_completed_notification()

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
    ):
        """Relax the connection's socket timeout while maintenance is running."""
        add_debug_log_for_notification(self.connection, notification)

        # A MOVING connection is already being handled elsewhere; and without
        # relaxed timeouts enabled there is nothing to apply here.
        if self.connection.maintenance_state == MaintenanceState.MOVING:
            return
        if not self.config.is_relaxed_timeouts_enabled():
            return

        relaxed = self.config.relaxed_timeout
        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(tmp_relaxed_timeout=relaxed)
        # Extend the timeout on the already-established socket as well.
        self.connection.update_current_socket_timeout(relaxed)
        if isinstance(notification, OSSNodeMigratingNotification):
            # Record this start notification's id so the timeouts are not
            # un-relaxed while more than one start notification has been
            # received before the final end notification arrives.
            self.connection.add_maint_start_notification(notification.id)

    def handle_maintenance_completed_notification(self):
        """Restore the connection's default timeout once maintenance finishes."""
        # Only reset timeouts when the state is not MOVING and relaxed
        # timeouts are enabled in the first place.
        if self.connection.maintenance_state == MaintenanceState.MOVING:
            return
        if not self.config.is_relaxed_timeouts_enabled():
            return

        add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED")
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection timeouts by providing
        # -1 as the relaxed timeout.
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
        # Forget the bookkeeping of received start / skipped end maintenance
        # notifications.
        self.connection.reset_received_notifications()

994 

995 

996class OSSMaintNotificationsHandler: 

997 def __init__( 

998 self, 

999 cluster_client: "MaintNotificationsAbstractRedisCluster", 

1000 config: MaintNotificationsConfig, 

1001 ) -> None: 

1002 self.cluster_client = cluster_client 

1003 self.config = config 

1004 self._processed_notifications = set() 

1005 self._in_progress = set() 

1006 self._lock = threading.RLock() 

1007 

1008 def get_handler_for_connection(self): 

1009 # Copy all data that should be shared between connections 

1010 # but each connection should have its own pool handler 

1011 # since each connection can be in a different state 

1012 copy = OSSMaintNotificationsHandler(self.cluster_client, self.config) 

1013 copy._processed_notifications = self._processed_notifications 

1014 copy._in_progress = self._in_progress 

1015 copy._lock = self._lock 

1016 return copy 

1017 

1018 def remove_expired_notifications(self): 

1019 with self._lock: 

1020 for notification in tuple(self._processed_notifications): 

1021 if notification.is_expired(): 

1022 self._processed_notifications.remove(notification) 

1023 

1024 def handle_notification(self, notification: MaintenanceNotification): 

1025 if isinstance(notification, OSSNodeMigratedNotification): 

1026 self.handle_oss_maintenance_completed_notification(notification) 

1027 else: 

1028 logger.error(f"Unhandled notification type: {notification}") 

1029 

1030 def handle_oss_maintenance_completed_notification( 

1031 self, notification: OSSNodeMigratedNotification 

1032 ): 

1033 self.remove_expired_notifications() 

1034 

1035 with self._lock: 

1036 if ( 

1037 notification in self._in_progress 

1038 or notification in self._processed_notifications 

1039 ): 

1040 # we are already handling this notification or it has already been processed 

1041 # we should skip in_progress notification since when we reinitialize the cluster 

1042 # we execute a CLUSTER SLOTS command that can use a different connection 

1043 # that has also has the notification and we don't want to 

1044 # process the same notification twice 

1045 return 

1046 

1047 logger.debug(f"Handling SMIGRATED notification: {notification}") 

1048 self._in_progress.add(notification) 

1049 

1050 # Extract the information about the src and destination nodes that are affected 

1051 # by the maintenance. nodes_to_slots_mapping structure: 

1052 # { 

1053 # "src_host:port": [ 

1054 # {"dest_host:port": "slot_range"}, 

1055 # ... 

1056 # ], 

1057 # ... 

1058 # } 

1059 additional_startup_nodes_info = [] 

1060 affected_nodes = set() 

1061 for ( 

1062 src_address, 

1063 dest_mappings, 

1064 ) in notification.nodes_to_slots_mapping.items(): 

1065 src_host, src_port = src_address.split(":") 

1066 src_node = self.cluster_client.nodes_manager.get_node( 

1067 host=src_host, port=src_port 

1068 ) 

1069 if src_node is not None: 

1070 affected_nodes.add(src_node) 

1071 

1072 for dest_mapping in dest_mappings: 

1073 for dest_address in dest_mapping.keys(): 

1074 dest_host, dest_port = dest_address.split(":") 

1075 additional_startup_nodes_info.append( 

1076 (dest_host, int(dest_port)) 

1077 ) 

1078 

1079 # Updates the cluster slots cache with the new slots mapping 

1080 # This will also update the nodes cache with the new nodes mapping 

1081 self.cluster_client.nodes_manager.initialize( 

1082 disconnect_startup_nodes_pools=False, 

1083 additional_startup_nodes_info=additional_startup_nodes_info, 

1084 ) 

1085 

1086 all_nodes = set(affected_nodes) 

1087 all_nodes = all_nodes.union( 

1088 self.cluster_client.nodes_manager.nodes_cache.values() 

1089 ) 

1090 

1091 for current_node in all_nodes: 

1092 if current_node.redis_connection is None: 

1093 continue 

1094 with current_node.redis_connection.connection_pool._lock: 

1095 if current_node in affected_nodes: 

1096 # mark for reconnect all in use connections to the node - this will force them to 

1097 # disconnect after they complete their current commands 

1098 # Some of them might be used by sub sub and we don't know which ones - so we disconnect 

1099 # all in flight connections after they are done with current command execution 

1100 for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): 

1101 conn.mark_for_reconnect() 

1102 

1103 if ( 

1104 current_node 

1105 not in self.cluster_client.nodes_manager.nodes_cache.values() 

1106 ): 

1107 # disconnect all free connections to the node - this node will be dropped 

1108 # from the cluster, so we don't need to revert the timeouts 

1109 for conn in current_node.redis_connection.connection_pool._get_free_connections(): 

1110 conn.disconnect() 

1111 

1112 # mark the notification as processed 

1113 self._processed_notifications.add(notification) 

1114 self._in_progress.remove(notification)