Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maint_notifications.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

245 statements  

1import enum 

2import ipaddress 

3import logging 

4import re 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Literal, Optional, Union 

9 

10from redis.typing import Number 

11 

12 

class MaintenanceState(enum.Enum):
    """Maintenance state a connection (or a pool's new-connection settings) can be in."""

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"

17 

18 

class EndpointType(enum.Enum):
    """Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Return the string value of the enum (e.g. ``"internal-ip"``)."""
        return self.value

31 

32 

# Imported only for static type checking; at runtime the connection types are
# referenced through string annotations, which avoids a runtime import of
# redis.connection from this module.
if TYPE_CHECKING:
    from redis.connection import (
        MaintNotificationsAbstractConnection,
        MaintNotificationsAbstractConnectionPool,
    )

38 

39 

class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent through push messages by Redis server.

    This class provides common functionality for all maintenance notifications including
    unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): ``time.monotonic()`` timestamp taken when the
            notification object was created
        expire_at (float): ``creation_time + ttl``, computed at construction
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        # Monotonic clock so expiry is unaffected by wall-clock adjustments.
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        return time.monotonic() > (self.creation_time + self.ttl)

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass

117 

118 

class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.

    Attributes:
        new_node_host (Optional[str]): Host of the replacement node; may be None
            when the push message did not include it
        new_node_port (Optional[int]): Port of the replacement node; may be None
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (Optional[str]): Hostname or IP address of the new replacement node
            new_node_port (Optional[int]): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        # !r keeps a missing host distinguishable: None vs the string 'None'.
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host={self.new_node_host!r}, "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they have the same
        id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type class name, id,
                new_node_host and new_node_port
        """
        try:
            node_port = int(self.new_node_port) if self.new_node_port else None
        except (TypeError, ValueError):
            # A port that cannot be read as an integer must not make the
            # notification unhashable; equal notifications still hash equally.
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )

197 

198 

class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratingNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratingNotification):
            return False
        # The exact-type check also keeps subclasses from comparing equal.
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

246 

247 

class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    # TTL in seconds applied to every instance (the push message carries no TTL).
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratedNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeMigratedNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

296 

297 

class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of failing over.

    This notification is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailingOverNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailingOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

345 

346 

class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed a failover.

    This notification is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    # TTL in seconds applied to every instance (the push message carries no TTL).
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailedOverNotification notifications are considered equal if they have the same
        id and are of the same type.
        """
        if not isinstance(other, NodeFailedOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))

395 

396 

397def _is_private_fqdn(host: str) -> bool: 

398 """ 

399 Determine if an FQDN is likely to be internal/private. 

400 

401 This uses heuristics based on RFC 952 and RFC 1123 standards: 

402 - .local domains (RFC 6762 - Multicast DNS) 

403 - .internal domains (common internal convention) 

404 - Single-label hostnames (no dots) 

405 - Common internal TLDs 

406 

407 Args: 

408 host (str): The FQDN to check 

409 

410 Returns: 

411 bool: True if the FQDN appears to be internal/private 

412 """ 

413 host_lower = host.lower().rstrip(".") 

414 

415 # Single-label hostnames (no dots) are typically internal 

416 if "." not in host_lower: 

417 return True 

418 

419 # Common internal/private domain patterns 

420 internal_patterns = [ 

421 r"\.local$", # mDNS/Bonjour domains 

422 r"\.internal$", # Common internal convention 

423 r"\.corp$", # Corporate domains 

424 r"\.lan$", # Local area network 

425 r"\.intranet$", # Intranet domains 

426 r"\.private$", # Private domains 

427 ] 

428 

429 for pattern in internal_patterns: 

430 if re.search(pattern, host_lower): 

431 return True 

432 

433 # If none of the internal patterns match, assume it's external 

434 return False 

435 

436 

class MaintNotificationsConfig:
    """
    Configuration class for maintenance notifications handling behaviour. Notifications are received through
    push notifications.

    This class defines how the Redis client should react to different push notifications
    such as node moving, migrations, etc. in a Redis cluster.
    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided - the relaxed timeout is disabled. If None is provided, operations
                block until a response is received. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in
                CLIENT MAINT_NOTIFICATIONS. Accepts an EndpointType member or its string value
                (e.g. "internal-ip"); strings are converted to the matching member.
                If None, the endpoint type is determined automatically from the host
                (see get_endpoint_type). Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            # Enforce the documented contract: only known endpoint types are accepted.
            try:
                endpoint_type = EndpointType(endpoint_type)
            except ValueError:
                valid = ", ".join(repr(member.value) for member in EndpointType)
                raise ValueError(
                    f"Invalid endpoint_type {endpoint_type!r}; "
                    f"expected an EndpointType or one of: {valid}"
                ) from None

        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
        If relaxed_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the user provided host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, use the connection's resolved IP to determine
             internal-fqdn vs external-fqdn
           - If no usable resolved IP is available, fall back to a name-based
             heuristic on the FQDN

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The explicitly configured type if set; otherwise
            INTERNAL_IP/EXTERNAL_IP for an IP host, or INTERNAL_FQDN/EXTERNAL_FQDN
            for a hostname.
        """
        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass
        else:
            # Host is an IP address - use it directly
            return (
                EndpointType.INTERNAL_IP
                if ip_addr.is_private
                else EndpointType.EXTERNAL_IP
            )

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
            except ValueError:
                # Not expected for an address taken from the socket; fall back below
                pass
            else:
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if ip_addr.is_private
                    else EndpointType.EXTERNAL_FQDN
                )

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN

555 

556 

class MaintNotificationsPoolHandler:
    """
    Applies pool-wide maintenance notifications (currently only MOVING) to a
    connection pool.

    ``get_handler_for_connection`` creates one handler per connection. All
    copies share the pool, the config, the processed-notifications set and the
    lock, so a notification received on several connections is applied to the
    pool only once while it has not expired.
    """

    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        """
        Args:
            pool: The connection pool the notifications are applied to.
            config: Maintenance notifications configuration.
        """
        self.pool = pool
        self.config = config
        # Notifications already applied to the pool; shared with every
        # per-connection copy created by get_handler_for_connection.
        self._processed_notifications = set()
        self._lock = threading.RLock()
        # Connection this handler is bound to, if any (see set_connection).
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        """Bind this handler to the connection that receives the notifications."""
        self.connection = connection

    def get_handler_for_connection(self):
        """
        Return a new handler for a single connection.

        Copy all data that should be shared between connections, but give each
        connection its own pool handler since each connection can be in a
        different state.
        """
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        """Drop notifications whose TTL has elapsed from the shared processed set."""
        with self._lock:
            # Iterate over a snapshot because the set is mutated in the loop.
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """
        Dispatch a pool-level notification.

        Expired notifications are pruned first. NodeMovingNotification is
        handled; any other type is logged as an error and ignored.
        """
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        # A named logger is used on purpose: the module-level logging.error()
        # helper calls logging.basicConfig() when the root logger has no handlers.
        logging.getLogger(__name__).error(
            "Unhandled notification type: %s", notification
        )

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        Apply a MOVING notification to the pool.

        Existing connections connected to the address being moved are switched
        to MaintenanceState.MOVING (optionally with the new host and relaxed
        timeouts), the pool's kwargs for new connections are updated the same
        way, and a timer is started that reverts everything after
        ``notification.ttl`` seconds (see handle_node_moved_notification).
        A notification that was already processed and has not expired is ignored.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            # Neither reconnects nor relaxed timeouts are wanted - nothing to do.
            return
        with self._lock:
            if notification in self._processed_notifications:
                # nothing to do in the connection pool handling
                # the notification has already been handled
                # just return
                return

            with self.pool._lock:
                # Get the current connected address - if any
                # This is the address that is being moved
                # and we need to handle only connections
                # connected to the same address
                moving_address_src = (
                    self.connection.getpeername() if self.connection else None
                )

                # Set pool in maintenance mode - executed only if the pool
                # supports it (BlockingConnectionPool)
                has_maintenance_mode = bool(
                    getattr(self.pool, "set_in_maintenance", False)
                )
                if has_maintenance_mode:
                    self.pool.set_in_maintenance(True)
                try:
                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # Without a new host, defer the reconnect to half the TTL
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does not contain
                        # the new node host - in this case we only update the timeouts
                        kwargs["host"] = notification.new_node_host
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs["socket_timeout"] = self.config.relaxed_timeout
                        kwargs["socket_connect_timeout"] = self.config.relaxed_timeout
                    self.pool.update_connection_kwargs(**kwargs)
                finally:
                    # Always leave maintenance mode, even if updating the pool
                    # raised - otherwise nothing would ever switch it back off.
                    if has_maintenance_mode:
                        self.pool.set_in_maintenance(False)

            # Revert the MOVING changes once the notification's TTL has elapsed.
            threading.Timer(
                notification.ttl,
                self.handle_node_moved_notification,
                args=(notification,),
            ).start()

            self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.

        Args:
            moving_address_src: Address being moved; passed through to the pool
                to select which connections are affected.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.

        Restores the original host and timeouts in the pool's kwargs for new
        connections - unless a newer MOVING notification has replaced them in
        the meantime - and resets existing connections that are still tagged
        with this notification's hash back to MaintenanceState.NONE.
        """
        notification_hash = hash(notification)

        with self._lock:
            # if the current maintenance_notification_hash in kwargs is not matching the notification
            # it means there has been a new moving notification after this one
            # and we don't need to revert the kwargs yet
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )

751 

752 

class MaintNotificationsConnectionHandler:
    """
    Handles per-connection maintenance notifications (migrating/migrated,
    failing over/failed over).

    While maintenance is in progress the connection's socket timeout is
    relaxed to ``config.relaxed_timeout``; when maintenance completes the
    temporary settings are reset. Connections in MaintenanceState.MOVING are
    left untouched, and nothing is done when relaxed timeouts are disabled.
    """

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        """
        Args:
            connection: The connection whose settings this handler adjusts.
            config: Maintenance notifications configuration.
        """
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a notification to the start/completed handler by its exact class."""
        # get the notification type by checking its class in the _NOTIFICATION_TYPES dict
        # (exact class lookup - subclasses of the listed types are not matched)
        notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)

        if notification_type is None:
            # A named logger is used on purpose: the module-level logging.error()
            # helper calls logging.basicConfig() when the root logger has no handlers.
            logging.getLogger(__name__).error(
                "Unhandled notification type: %s", notification
            )
            return

        if notification_type:
            self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
        else:
            self.handle_maintenance_completed_notification()

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState
    ):
        """
        Enter maintenance: store ``maintenance_state`` on the connection and
        relax its timeouts to ``config.relaxed_timeout``.

        Skipped when the connection is in MaintenanceState.MOVING (the pool
        handler manages those) or when relaxed timeouts are disabled.
        """
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(
            tmp_relaxed_timeout=self.config.relaxed_timeout
        )
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relaxed_timeout)

    def handle_maintenance_completed_notification(self):
        """
        Leave maintenance: reset the connection's temporary settings and
        timeouts and set its state back to MaintenanceState.NONE.
        """
        # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relaxed timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE