Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/maintenance_events.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

239 statements  

1import enum 

2import ipaddress 

3import logging 

4import re 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from typing import TYPE_CHECKING, Optional, Union 

9 

10from redis.typing import Number 

11 

12 

class MaintenanceState(enum.Enum):
    """
    Maintenance state tracked for connections and connection settings.

    NONE        - no maintenance in progress.
    MOVING      - a node-moving (MOVING) notification is being handled.
    MAINTENANCE - a migration/failover style maintenance is in progress.
    """

    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"

17 

18 

class EndpointType(enum.Enum):
    """
    Endpoint types accepted by the CLIENT MAINT_NOTIFICATIONS command.

    Converting a member with ``str()`` yields its raw value
    (e.g. ``"internal-ip"``) rather than ``"EndpointType.INTERNAL_IP"``.
    """

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Render the member as its raw string value."""
        return str(self.value)

31 

32 

33if TYPE_CHECKING: 

34 from redis.connection import ( 

35 BlockingConnectionPool, 

36 ConnectionInterface, 

37 ConnectionPool, 

38 ) 

39 

40 

class MaintenanceEvent(ABC):
    """
    Abstract base for maintenance events delivered as Redis push messages.

    Each event carries an identifier and a time-to-live. The TTL is measured
    on the monotonic clock, starting when the object is constructed.

    Attributes:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): ``time.monotonic()`` reading taken at construction
        expire_at (float): ``creation_time + ttl`` - the monotonic deadline
    """

    def __init__(self, id: int, ttl: int):
        """
        Store the id and ttl and stamp the creation time.

        Args:
            id (int): Unique identifier for this event
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        stamped_at = time.monotonic()
        self.creation_time = stamped_at
        self.expire_at = stamped_at + ttl

    def is_expired(self) -> bool:
        """
        Report whether this event's TTL has elapsed.

        Returns:
            bool: True once the monotonic clock is strictly past
            ``creation_time + ttl``, False otherwise
        """
        deadline = self.creation_time + self.ttl
        return deadline < time.monotonic()

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a human-readable description of the event.

        Concrete subclasses must override this.

        Returns:
            str: String representation of the event
        """

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Decide whether two maintenance events are the same event.

        Concrete subclasses must override this. Events are typically
        considered equal when they share the same id and concrete type.

        Args:
            other: The object to compare with

        Returns:
            bool: True if the events are equal, False otherwise
        """

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash consistent with ``__eq__``.

        Concrete subclasses must override this so events can be stored in
        sets and used as dictionary keys.

        Returns:
            int: Hash value for the event
        """

118 

119 

class NodeMovingEvent(MaintenanceEvent):
    """
    This event is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.

    Attributes:
        new_node_host (Optional[str]): Host of the replacement node, or None
            when the MOVING notification did not include one.
        new_node_port (Optional[int]): Port of the replacement node, or None.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingEvent.

        Args:
            id (int): Unique identifier for this event
            new_node_host (Optional[str]): Hostname or IP address of the new
                replacement node, or None if not provided
            new_node_port (Optional[int]): Port number of the new replacement
                node, or None if not provided
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingEvent events are considered equal if they are of the
        same concrete type and have the same id, new_node_host and
        new_node_port.

        The exact-type check matters because ``__hash__`` includes the class
        name. Without it, a subclass instance could compare equal to a base
        instance while hashing differently, breaking the hash/eq contract
        relied on by sets and dicts. The sibling event classes use the same
        check.
        """
        if not isinstance(other, NodeMovingEvent):
            return False
        return (
            type(self) is type(other)
            and self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the event to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on event type class name, id,
            new_node_host and new_node_port
        """
        try:
            # Normalise the port so e.g. "6379" and 6379 hash alike; a
            # missing/falsy port hashes as None.
            node_port = int(self.new_node_port) if self.new_node_port else None
        except (TypeError, ValueError):
            # Unparseable port: use a fixed sentinel so hashing never raises.
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )

198 

199 

class NodeMigratingEvent(MaintenanceEvent):
    """
    Event signalling that a Redis cluster node has started migrating slots.

    Received when a node begins moving its slots to another node during
    cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        parts = (
            f"id={self.id}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={time_left:.1f}s",
            f"expired={self.is_expired()}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other) -> bool:
        """
        Equal when *other* has exactly the same event type and the same id.
        """
        if not isinstance(other, NodeMigratingEvent):
            return False
        same_id = self.id == other.id
        return same_id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Hash on (class name, integer id) so equal events hash alike.

        Returns:
            int: Hash value based on event type and id
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)

247 

248 

class NodeMigratedEvent(MaintenanceEvent):
    """
    Event signalling that a Redis cluster node has finished migrating slots.

    Received when a node has moved all its slots to other nodes during
    cluster rebalancing or maintenance operations. Its TTL is fixed to
    ``DEFAULT_TTL`` seconds.

    Args:
        id (int): Unique identifier for this event
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        parts = (
            f"id={self.id}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={time_left:.1f}s",
            f"expired={self.is_expired()}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other) -> bool:
        """
        Equal when *other* has exactly the same event type and the same id.
        """
        if not isinstance(other, NodeMigratedEvent):
            return False
        same_id = self.id == other.id
        return same_id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Hash on (class name, integer id) so equal events hash alike.

        Returns:
            int: Hash value based on event type and id
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)

297 

298 

class NodeFailingOverEvent(MaintenanceEvent):
    """
    Event signalling that a Redis cluster node has started failing over.

    Received when a node begins a failover during cluster maintenance
    operations or while handling node failures.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        parts = (
            f"id={self.id}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={time_left:.1f}s",
            f"expired={self.is_expired()}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other) -> bool:
        """
        Equal when *other* has exactly the same event type and the same id.
        """
        if not isinstance(other, NodeFailingOverEvent):
            return False
        same_id = self.id == other.id
        return same_id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Hash on (class name, integer id) so equal events hash alike.

        Returns:
            int: Hash value based on event type and id
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)

346 

347 

class NodeFailedOverEvent(MaintenanceEvent):
    """
    Event signalling that a Redis cluster node has completed a failover.

    Received when a node has finished failing over during cluster
    maintenance operations or after handling node failures. Its TTL is
    fixed to ``DEFAULT_TTL`` seconds.

    Args:
        id (int): Unique identifier for this event
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        parts = (
            f"id={self.id}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={time_left:.1f}s",
            f"expired={self.is_expired()}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other) -> bool:
        """
        Equal when *other* has exactly the same event type and the same id.
        """
        if not isinstance(other, NodeFailedOverEvent):
            return False
        same_id = self.id == other.id
        return same_id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Hash on (class name, integer id) so equal events hash alike.

        Returns:
            int: Hash value based on event type and id
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)

396 

397 

398def _is_private_fqdn(host: str) -> bool: 

399 """ 

400 Determine if an FQDN is likely to be internal/private. 

401 

402 This uses heuristics based on RFC 952 and RFC 1123 standards: 

403 - .local domains (RFC 6762 - Multicast DNS) 

404 - .internal domains (common internal convention) 

405 - Single-label hostnames (no dots) 

406 - Common internal TLDs 

407 

408 Args: 

409 host (str): The FQDN to check 

410 

411 Returns: 

412 bool: True if the FQDN appears to be internal/private 

413 """ 

414 host_lower = host.lower().rstrip(".") 

415 

416 # Single-label hostnames (no dots) are typically internal 

417 if "." not in host_lower: 

418 return True 

419 

420 # Common internal/private domain patterns 

421 internal_patterns = [ 

422 r"\.local$", # mDNS/Bonjour domains 

423 r"\.internal$", # Common internal convention 

424 r"\.corp$", # Corporate domains 

425 r"\.lan$", # Local area network 

426 r"\.intranet$", # Intranet domains 

427 r"\.private$", # Private domains 

428 ] 

429 

430 for pattern in internal_patterns: 

431 if re.search(pattern, host_lower): 

432 return True 

433 

434 # If none of the internal patterns match, assume it's external 

435 return False 

436 

437 

class MaintenanceEventsConfig:
    """
    Configuration class for maintenance events handling behaviour. Events are
    received through push notifications.

    This class defines how the Redis client should react to different push
    notifications such as node moving, migrations, etc. in a Redis cluster.
    """

    def __init__(
        self,
        enabled: bool = True,
        proactive_reconnect: bool = True,
        relax_timeout: Optional[Number] = 20,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintenanceEventsConfig.

        Args:
            enabled (bool): Whether to enable maintenance events handling.
                Defaults to True.
            proactive_reconnect (bool): Whether to proactively reconnect when
                a node is replaced. Defaults to True.
            relax_timeout (Optional[Number]): The relax timeout to use for the
                connection during maintenance. If -1 is provided, the relax
                timeout is disabled. Defaults to 20.
            endpoint_type (Optional[EndpointType]): Override for the endpoint
                type to use in CLIENT MAINT_NOTIFICATIONS. An EndpointType
                member or its string value (e.g. "internal-ip") is accepted;
                strings are converted to the enum. If None, the endpoint type
                is determined automatically from the host and the resolved
                address. Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid
                endpoint type.
        """
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            # Enforce the documented contract: EndpointType(value) returns the
            # matching member for a known value and raises ValueError otherwise.
            endpoint_type = EndpointType(endpoint_type)

        self.enabled = enabled
        self.relax_timeout = relax_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relax_timeout={self.relax_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relax_timeouts_enabled(self) -> bool:
        """
        Check if the relax_timeout is enabled. The '-1' value is used to
        disable the relax_timeout. If relax_timeout is set to None, it will
        make the operation blocking and waiting until any response is received.

        Returns:
            True if the relax_timeout is enabled, False otherwise.
        """
        return self.relax_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "ConnectionInterface"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it.
        2. Otherwise, if ``host`` is an IP address, classify it directly as
           internal-ip or external-ip.
        3. Otherwise (``host`` is an FQDN), classify the connection's resolved
           IP as internal-fqdn or external-fqdn.
        4. If no usable resolved IP is available, fall back to name-based
           heuristics on the FQDN itself.

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type
                determination (only ``get_resolved_ip()`` is used)

        Returns:
            EndpointType: The configured override when set; otherwise one of
            INTERNAL_IP, EXTERNAL_IP, INTERNAL_FQDN or EXTERNAL_FQDN.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fallback
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN

551 

552 

class MaintenanceEventPoolHandler:
    """
    Applies maintenance events to a connection pool.

    Only NodeMovingEvent is handled here; other event types are logged as
    errors. A MOVING event is applied at most once while it is unexpired.
    Applying it does the following:

    - switches matching pooled connections to the MOVING state (passing the
      relax timeout and the new host);
    - updates the pool's kwargs for new connections in the same way;
    - schedules ``handle_node_moved_event`` after ``ttl`` seconds to revert
      those changes, unless a newer MOVING event has superseded it.
    """

    def __init__(
        self,
        pool: Union["ConnectionPool", "BlockingConnectionPool"],
        config: MaintenanceEventsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Events already applied to the pool; pruned by
        # remove_expired_notifications().
        self._processed_events = set()
        # Re-entrant on purpose: run_proactive_reconnect() acquires it again
        # while handle_node_moving_event() already holds it.
        self._lock = threading.RLock()
        # Connection the event arrived on (see set_connection); its peer
        # address selects which pooled connections are affected.
        self.connection = None

    def set_connection(self, connection: "ConnectionInterface"):
        """Record the connection whose notifications this handler processes."""
        self.connection = connection

    def remove_expired_notifications(self):
        """Drop processed events whose TTL has elapsed."""
        with self._lock:
            # Iterate over a snapshot because the set is mutated in the loop.
            for notification in tuple(self._processed_events):
                if notification.is_expired():
                    self._processed_events.remove(notification)

    def handle_event(self, notification: MaintenanceEvent):
        """
        Entry point for pool-level maintenance events.

        Expired events are pruned first. NodeMovingEvent is dispatched to
        handle_node_moving_event(); any other type is logged as an error.
        """
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingEvent):
            return self.handle_node_moving_event(notification)
        else:
            logging.error("Unhandled notification type: %s", notification)

    def handle_node_moving_event(self, event: NodeMovingEvent):
        """
        Apply a MOVING notification to the pool.

        Does nothing when both proactive reconnect and relax timeouts are
        disabled, or when the same event is already recorded as processed.

        If the pool exposes ``set_in_maintenance`` (e.g. BlockingConnectionPool),
        the pool is put in maintenance mode while it is updated. The mode is
        always cleared afterwards, even if one of the pool updates raises, so
        the pool cannot be left stuck in maintenance mode.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relax_timeouts_enabled()
        ):
            return
        with self._lock:
            if event in self._processed_events:
                # Already applied and not yet pruned as expired -
                # nothing to do in the connection pool handling.
                return

            with self.pool._lock:
                # The early return above guarantees that proactive reconnect
                # and/or relax timeouts are enabled at this point.

                # Get the current connected address - if any.
                # This is the address that is being moved, and only
                # connections to that same address need handling.
                moving_address_src = (
                    self.connection.getpeername() if self.connection else None
                )

                # Maintenance mode only exists on some pools
                # (BlockingConnectionPool).
                supports_maintenance = getattr(
                    self.pool, "set_in_maintenance", False
                )
                if supports_maintenance:
                    self.pool.set_in_maintenance(True)

                try:
                    # Update maintenance state, timeout and optionally host
                    # address connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_event_hash=hash(event),
                        relax_timeout=self.config.relax_timeout,
                        host_address=event.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_event_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if event.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No new host in the notification: defer the
                            # reconnect to half of the event's TTL.
                            threading.Timer(
                                event.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # Set state to MOVING
                    # update host
                    # if relax timeouts are enabled - update timeouts
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_event_hash": hash(event),
                    }
                    if event.new_node_host is not None:
                        # the host is not updated if the new node host is None
                        # this happens when the MOVING push notification does
                        # not contain the new node host - in this case we only
                        # update the timeouts
                        kwargs.update(
                            {
                                "host": event.new_node_host,
                            }
                        )
                    if self.config.is_relax_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relax_timeout,
                                "socket_connect_timeout": self.config.relax_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)
                finally:
                    # Never leave the pool stuck in maintenance mode.
                    if supports_maintenance:
                        self.pool.set_in_maintenance(False)

            # Revert the MOVING changes once the event's TTL elapses.
            threading.Timer(
                event.ttl, self.handle_node_moved_event, args=(event,)
            ).start()

            self._processed_events.add(event)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.

        Args:
            moving_address_src: Peer address of the moving node, forwarded to
                the pool to restrict which connections are affected.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_event(self, event: NodeMovingEvent):
        """
        Handle the cleanup after a node moving event expires.

        The pool's kwargs for new connections are restored from the saved
        ``orig_*`` values only if they are still tagged with this event's
        hash. Existing connections tagged with this event's hash are reset to
        MaintenanceState.NONE.
        """
        event_hash = hash(event)

        with self._lock:
            # if the current maintenance_event_hash in kwargs is not matching the event
            # it means there has been a new moving event after this one
            # and we don't need to revert the kwargs yet
            if self.pool.connection_kwargs.get("maintenance_event_hash") == event_hash:
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_event_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relax_timeout = self.config.is_relax_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relax_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_event_hash=None,
                    matching_event_hash=event_hash,
                    matching_pattern="event_hash",
                    update_event_hash=True,
                    reset_relax_timeout=reset_relax_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )

732 

733 

class MaintenanceEventConnectionHandler:
    """
    Per-connection handler for migration and failover maintenance events.

    "Start" events (migrating / failing over) relax the connection's socket
    timeout. "Completed" events (migrated / failed over) restore it. Nothing
    is changed while the connection is in the MOVING state or when relax
    timeouts are disabled in the config.
    """

    # 1 = "starting maintenance" events, 0 = "completed maintenance" events
    _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = {
        NodeMigratingEvent: 1,
        NodeFailingOverEvent: 1,
        NodeMigratedEvent: 0,
        NodeFailedOverEvent: 0,
    }

    def __init__(
        self, connection: "ConnectionInterface", config: MaintenanceEventsConfig
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_event(self, event: MaintenanceEvent):
        """
        Route *event* by its exact class to the start or completion handler.

        Classes not listed in ``_EVENT_TYPES`` are logged as errors and ignored.
        """
        starts_maintenance = self._EVENT_TYPES.get(event.__class__, None)

        if starts_maintenance is None:
            logging.error(f"Unhandled event type: {event}")
            return

        if not starts_maintenance:
            self.handle_maintenance_completed_event()
            return

        self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)

    def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
        """
        Enter *maintenance_state* and apply the relaxed timeout.

        Skipped while the connection is MOVING or when relax timeouts are off.
        """
        conn = self.connection
        if (
            conn.maintenance_state != MaintenanceState.MOVING
            and self.config.is_relax_timeouts_enabled()
        ):
            relaxed = self.config.relax_timeout
            conn.maintenance_state = maintenance_state
            conn.set_tmp_settings(tmp_relax_timeout=relaxed)
            # Also stretch the timeout of the currently open socket.
            conn.update_current_socket_timeout(relaxed)

    def handle_maintenance_completed_event(self):
        """
        Undo the relaxed timeout and return the connection to state NONE.

        Skipped while the connection is MOVING or when relax timeouts are off.
        """
        conn = self.connection
        if (
            conn.maintenance_state != MaintenanceState.MOVING
            and self.config.is_relax_timeouts_enabled()
        ):
            conn.reset_tmp_settings(reset_relax_timeout=True)
            # -1 as the relax timeout restores the connection's normal timeouts.
            conn.update_current_socket_timeout(-1)
            conn.maintenance_state = MaintenanceState.NONE