Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maint_notifications.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

239 statements  

1import enum 

2import ipaddress 

3import logging 

4import re 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Literal, Optional, Union 

9 

10from redis.typing import Number 

11 

12 

class MaintenanceState(enum.Enum):
    """Maintenance state assigned to connections and to new-connection settings."""

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"

17 

18 

class EndpointType(enum.Enum):
    """Endpoint type values accepted by the CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Render the member as its raw protocol token, e.g. ``"internal-ip"``."""
        token = self.value
        return token

31 

32 

33if TYPE_CHECKING: 

34 from redis.connection import ( 

35 BlockingConnectionPool, 

36 ConnectionInterface, 

37 ConnectionPool, 

38 ) 

39 

40 

class MaintenanceNotification(ABC):
    """
    Abstract base for maintenance notifications that the Redis server sends
    as push messages.

    Every notification carries an identifier and a time-to-live; expiry is
    measured against ``time.monotonic()`` captured at construction.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): ``time.monotonic()`` value when the notification was created/read
        expire_at (float): ``creation_time + ttl``
    """

    def __init__(self, id: int, ttl: int):
        """
        Store the identifier and TTL and stamp the creation time.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        now = time.monotonic()
        self.id = id
        self.ttl = ttl
        self.creation_time = now
        self.expire_at = now + ttl

    def is_expired(self) -> bool:
        """
        Report whether the TTL has elapsed since creation.

        Returns:
            bool: True once ``creation_time + ttl`` lies in the past, False otherwise
        """
        deadline = self.creation_time + self.ttl
        return deadline < time.monotonic()

    @abstractmethod
    def __repr__(self) -> str:
        """
        Describe the notification as a string; concrete subclasses must implement it.

        Returns:
            str: String representation of the notification
        """
        ...

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare with another object; concrete subclasses must implement it.

        Notifications are typically equal when they share the same id and type.

        Args:
            other: The object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        ...

    @abstractmethod
    def __hash__(self) -> int:
        """
        Hash the notification so it can live in sets and dict keys; concrete
        subclasses must implement it consistently with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        ...

118 

119 

class NodeMovingNotification(MaintenanceNotification):
    """
    Notification that a node is being replaced by a new node during cluster
    rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Create a MOVING notification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (str): Hostname or IP address of the replacement node (may be None)
            new_node_port (int): Port number of the replacement node (may be None)
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        deadline = self.expire_at
        time_left = max(0, deadline - time.monotonic())
        parts = [
            f"id={self.id}",
            f"new_node_host='{self.new_node_host}'",
            f"new_node_port={self.new_node_port}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={time_left:.1f}s",
            f"expired={self.is_expired()}",
        ]
        return f"{type(self).__name__}(" + ", ".join(parts) + ")"

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` is a NodeMovingNotification with the same id,
        new_node_host and new_node_port.
        """
        if isinstance(other, NodeMovingNotification):
            return (
                self.id == other.id
                and self.new_node_host == other.new_node_host
                and self.new_node_port == other.new_node_port
            )
        return False

    def __hash__(self) -> int:
        """
        Hash on the class name, integer id, stringified host and port.

        A falsy port hashes as None; a port that cannot be converted with
        ``int()`` (ValueError) hashes as 0.

        Returns:
            int: Hash value for the notification
        """
        try:
            port_key = None
            if self.new_node_port:
                port_key = int(self.new_node_port)
        except ValueError:
            port_key = 0

        identity = (
            type(self).__name__,
            int(self.id),
            str(self.new_node_host),
            port_key,
        )
        return hash(identity)

198 

199 

class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has started migrating slots.

    Received when a node begins moving its slots to another node during
    cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        return (
            f"{type(self).__name__}(id={self.id}, ttl={self.ttl}, "
            f"creation_time={self.creation_time}, expires_at={deadline}, "
            f"remaining={time_left:.1f}s, expired={self.is_expired()})"
        )

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` has exactly the same type and the same id.
        """
        return (
            isinstance(other, NodeMigratingNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and integer id, consistent with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        return hash((type(self).__name__, int(self.id)))

247 

248 

class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has finished migrating slots.

    Received when a node has moved all its slots to other nodes during
    cluster rebalancing or maintenance operations. The TTL is fixed to
    ``DEFAULT_TTL`` seconds.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        return (
            f"{type(self).__name__}(id={self.id}, ttl={self.ttl}, "
            f"creation_time={self.creation_time}, expires_at={deadline}, "
            f"remaining={time_left:.1f}s, expired={self.is_expired()})"
        )

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` has exactly the same type and the same id.
        """
        return (
            isinstance(other, NodeMigratedNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and integer id, consistent with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        return hash((type(self).__name__, int(self.id)))

297 

298 

class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has started a failover.

    Received when a node begins failing over during cluster maintenance
    operations or while handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        return (
            f"{type(self).__name__}(id={self.id}, ttl={self.ttl}, "
            f"creation_time={self.creation_time}, expires_at={deadline}, "
            f"remaining={time_left:.1f}s, expired={self.is_expired()})"
        )

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` has exactly the same type and the same id.
        """
        return (
            isinstance(other, NodeFailingOverNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and integer id, consistent with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        return hash((type(self).__name__, int(self.id)))

346 

347 

class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification that a Redis cluster node has completed a failover.

    Received when a node has finished failing over during cluster
    maintenance operations or after handling node failures. The TTL is fixed
    to ``DEFAULT_TTL`` seconds.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        return (
            f"{type(self).__name__}(id={self.id}, ttl={self.ttl}, "
            f"creation_time={self.creation_time}, expires_at={deadline}, "
            f"remaining={time_left:.1f}s, expired={self.is_expired()})"
        )

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` has exactly the same type and the same id.
        """
        return (
            isinstance(other, NodeFailedOverNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and integer id, consistent with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        return hash((type(self).__name__, int(self.id)))

396 

397 

398def _is_private_fqdn(host: str) -> bool: 

399 """ 

400 Determine if an FQDN is likely to be internal/private. 

401 

402 This uses heuristics based on RFC 952 and RFC 1123 standards: 

403 - .local domains (RFC 6762 - Multicast DNS) 

404 - .internal domains (common internal convention) 

405 - Single-label hostnames (no dots) 

406 - Common internal TLDs 

407 

408 Args: 

409 host (str): The FQDN to check 

410 

411 Returns: 

412 bool: True if the FQDN appears to be internal/private 

413 """ 

414 host_lower = host.lower().rstrip(".") 

415 

416 # Single-label hostnames (no dots) are typically internal 

417 if "." not in host_lower: 

418 return True 

419 

420 # Common internal/private domain patterns 

421 internal_patterns = [ 

422 r"\.local$", # mDNS/Bonjour domains 

423 r"\.internal$", # Common internal convention 

424 r"\.corp$", # Corporate domains 

425 r"\.lan$", # Local area network 

426 r"\.intranet$", # Intranet domains 

427 r"\.private$", # Private domains 

428 ] 

429 

430 for pattern in internal_patterns: 

431 if re.search(pattern, host_lower): 

432 return True 

433 

434 # If none of the internal patterns match, assume it's external 

435 return False 

436 

437 

class MaintNotificationsConfig:
    """
    Configuration for how the client handles maintenance notifications.

    Maintenance notifications (node moving, slot migration, failover, ...)
    are received as push notifications; this class controls whether they are
    requested and how the client reacts to them.
    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[Union[EndpointType, str]] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType | str]): Override for the endpoint type to use in
                CLIENT MAINT_NOTIFICATIONS, given either as an EndpointType member or its string
                value (e.g. "internal-ip"). If None, the endpoint type is determined automatically
                from the host. Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        # EndpointType(...) returns members unchanged, converts valid string
        # values to members and raises ValueError for anything else - this
        # enforces the documented "Raises" contract.
        self.endpoint_type = (
            EndpointType(endpoint_type) if endpoint_type is not None else None
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "ConnectionInterface"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for the CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly configured, use it.
        2. Otherwise inspect the user-provided host:
           - If host is an IP address, classify it directly as internal-ip vs external-ip.
           - If host is an FQDN, classify the connection's resolved IP as
             internal-fqdn vs external-fqdn; when no usable resolved IP is
             available, fall back to name-based heuristics.

        Args:
            host: User provided hostname to analyze
            connection: The connection whose resolved IP is consulted for FQDN hosts

        Returns:
            EndpointType: The configured endpoint type if set; otherwise
            INTERNAL_IP/EXTERNAL_IP for an IP literal, or
            INTERNAL_FQDN/EXTERNAL_FQDN for a hostname.
        """
        if self.endpoint_type is not None:
            return self.endpoint_type

        try:
            host_ip = ipaddress.ip_address(host)
        except ValueError:
            pass  # Not an IP literal - handled as an FQDN below.
        else:
            return (
                EndpointType.INTERNAL_IP
                if host_ip.is_private
                else EndpointType.EXTERNAL_IP
            )

        # Host is an FQDN: classify by the IP the connection resolved it to.
        resolved_ip = connection.get_resolved_ip()
        if resolved_ip:
            try:
                resolved_addr = ipaddress.ip_address(resolved_ip)
            except ValueError:
                # Not expected for an address obtained from the socket;
                # fall back to the name heuristics below.
                pass
            else:
                return (
                    EndpointType.INTERNAL_FQDN
                    if resolved_addr.is_private
                    else EndpointType.EXTERNAL_FQDN
                )

        # Final fallback: heuristics on the FQDN itself.
        return (
            EndpointType.INTERNAL_FQDN
            if _is_private_fqdn(host)
            else EndpointType.EXTERNAL_FQDN
        )

556 

557 

class MaintNotificationsPoolHandler:
    """
    Pool-level handler for maintenance notifications.

    Only NodeMovingNotification is acted upon. Matching pool connections and
    the kwargs used for new connections are switched to the MOVING state
    (relaxed timeouts and, when known, the new host). A timer reverts those
    changes once the notification's TTL has elapsed.
    """

    def __init__(
        self,
        pool: Union["ConnectionPool", "BlockingConnectionPool"],
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Notifications already applied to the pool; a notification found here
        # is not applied again. Expired entries are pruned on each new notification.
        self._processed_notifications = set()
        # Reentrant: run_proactive_reconnect acquires it again while
        # handle_node_moving_notification already holds it.
        self._lock = threading.RLock()
        # Connection whose peer address identifies the node being moved.
        self.connection = None

    def set_connection(self, connection: "ConnectionInterface"):
        """Set the connection whose peer address selects the affected pool connections."""
        self.connection = connection

    def remove_expired_notifications(self):
        """Drop processed notifications whose TTL has elapsed."""
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """
        Dispatch a pool-level maintenance notification.

        Expired entries are pruned from the processed set first. Only
        NodeMovingNotification is handled; any other type is logged as an error.
        """
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        logging.error("Unhandled notification type: %s", notification)

    @staticmethod
    def _start_timer(delay, function, args) -> None:
        """Run ``function(*args)`` after ``delay`` seconds on a daemon timer thread."""
        timer = threading.Timer(delay, function, args=args)
        # Daemon, so a pending revert/reconnect does not keep the interpreter
        # alive for up to a full TTL at shutdown.
        timer.daemon = True
        timer.start()

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        Apply a MOVING notification to the pool.

        Does nothing if neither proactive reconnect nor relaxed timeouts are
        enabled, or if this notification has already been processed. Otherwise:
        - connections to the moving address get MOVING state, the relaxed
          timeout and the new host;
        - with proactive reconnect, connections are recycled immediately when a
          new host is known, otherwise after half the TTL;
        - kwargs for new connections are updated the same way;
        - a timer reverts everything after the TTL.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        notification_hash = hash(notification)

        with self._lock:
            if notification in self._processed_notifications:
                # Already handled (and not yet expired): nothing to do.
                return

            with self.pool._lock:
                # Address of the node being moved; only connections to this
                # address are affected.
                moving_address_src = (
                    self.connection.getpeername() if self.connection else None
                )

                # Only pools that implement set_in_maintenance
                # (BlockingConnectionPool) are switched into maintenance mode.
                toggles_maintenance = getattr(self.pool, "set_in_maintenance", False)
                if toggles_maintenance:
                    self.pool.set_in_maintenance(True)
                try:
                    # Update maintenance state, timeout and optionally host
                    # address for matching existing connections.
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=notification_hash,
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No new host in the notification: defer the
                            # reconnect to half of the TTL.
                            self._start_timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                (moving_address_src,),
                            )

                    # Settings for connections created from now on.
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": notification_hash,
                    }
                    if notification.new_node_host is not None:
                        # The host is only replaced when the MOVING push
                        # notification carried a new node host.
                        kwargs["host"] = notification.new_node_host
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs["socket_timeout"] = self.config.relaxed_timeout
                        kwargs["socket_connect_timeout"] = self.config.relaxed_timeout
                    self.pool.update_connection_kwargs(**kwargs)
                finally:
                    # Always leave maintenance mode, even if an update above
                    # raised; otherwise the pool would stay blocked.
                    if toggles_maintenance:
                        self.pool.set_in_maintenance(False)

            self._start_timer(
                notification.ttl,
                self.handle_node_moved_notification,
                (notification,),
            )

            self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # Active connections: reconnect after the current command.
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # Free connections: disconnect now; reconnected on next use.
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Revert the effects of a MOVING notification after its TTL (timer callback).

        The pool's connection kwargs are restored from the ``orig_*`` values
        only if they are still tagged with this notification's hash; a
        different hash means a newer MOVING notification superseded this one.
        Connections still tagged with this notification's hash get their
        state reset, plus their timeout (when relaxed timeouts are enabled)
        and host address (when proactive reconnect is enabled).
        """
        notification_hash = hash(notification)

        with self._lock:
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )

742 

743 

class MaintNotificationsConnectionHandler:
    """
    Connection-level handler for migrating/migrated and failing-over/failed-over
    notifications: relaxes the connection's socket timeout when maintenance
    starts and restores it when maintenance completes.
    """

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
    }

    def __init__(
        self, connection: "ConnectionInterface", config: MaintNotificationsConfig
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """Route a notification to the start/completed handler by its class."""
        phase = self._NOTIFICATION_TYPES.get(notification.__class__, None)

        if phase is None:
            logging.error(f"Unhandled notification type: {notification}")
            return

        if not phase:
            self.handle_maintenance_completed_notification()
        else:
            self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState
    ):
        """
        Enter ``maintenance_state`` and apply the relaxed timeout.

        Skipped while the connection is MOVING or when relaxed timeouts are disabled.
        """
        conn = self.connection
        if conn.maintenance_state == MaintenanceState.MOVING:
            return
        if not self.config.is_relaxed_timeouts_enabled():
            return

        conn.maintenance_state = maintenance_state
        conn.set_tmp_settings(tmp_relaxed_timeout=self.config.relaxed_timeout)
        # Also extend the timeout of the currently open socket.
        conn.update_current_socket_timeout(self.config.relaxed_timeout)

    def handle_maintenance_completed_notification(self):
        """
        Leave maintenance: restore timeouts and set the state back to NONE.

        Skipped while the connection is MOVING or when relaxed timeouts are disabled.
        """
        conn = self.connection
        if conn.maintenance_state == MaintenanceState.MOVING:
            return
        if not self.config.is_relaxed_timeouts_enabled():
            return

        conn.reset_tmp_settings(reset_relaxed_timeout=True)
        # -1 as the relaxed timeout restores the socket's original timeout.
        conn.update_current_socket_timeout(-1)
        conn.maintenance_state = MaintenanceState.NONE

799 self.connection.maintenance_state = MaintenanceState.NONE