Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/cluster.py: 19%


1482 statements  

1import asyncio 

2import collections 

3import logging 

4import random 

5import socket 

6import threading 

7import time 

8import warnings 

9import weakref 

10from abc import ABC, abstractmethod 

11from collections import defaultdict 

12from copy import copy 

13from itertools import chain 

14from types import MethodType 

15from typing import ( 

16 TYPE_CHECKING, 

17 Any, 

18 Callable, 

19 Coroutine, 

20 Deque, 

21 Dict, 

22 Generator, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Set, 

28 Tuple, 

29 Type, 

30 TypeVar, 

31 Union, 

32) 

33 

34if TYPE_CHECKING: 

35 from redis.asyncio.keyspace_notifications import ( 

36 AsyncClusterKeyspaceNotifications, 

37 ) 

38 

39from redis._parsers import AsyncCommandsParser, Encoder 

40from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy 

41from redis._parsers.helpers import ( 

42 _RedisCallbacks, 

43 _RedisCallbacksRESP2, 

44 _RedisCallbacksRESP3, 

45) 

46from redis.asyncio.client import PubSub, ResponseCallbackT 

47from redis.asyncio.connection import ( 

48 AbstractConnection, 

49 Connection, 

50 ConnectionPoolInterface, 

51 SSLConnection, 

52 parse_url, 

53) 

54from redis.asyncio.lock import Lock 

55from redis.asyncio.observability.recorder import ( 

56 record_error_count, 

57 record_operation_duration, 

58) 

59from redis.asyncio.retry import Retry 

60from redis.auth.token import TokenInterface 

61from redis.backoff import ExponentialWithJitterBackoff, NoBackoff 

62from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis 

63from redis.cluster import ( 

64 PIPELINE_BLOCKED_COMMANDS, 

65 PRIMARY, 

66 REPLICA, 

67 SLOT_ID, 

68 AbstractRedisCluster, 

69 LoadBalancer, 

70 LoadBalancingStrategy, 

71 block_pipeline_command, 

72 get_node_name, 

73 parse_cluster_slots, 

74) 

75from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands 

76from redis.commands.helpers import list_or_args 

77from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver 

78from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot 

79from redis.credentials import CredentialProvider 

80from redis.driver_info import DriverInfo, resolve_driver_info 

81from redis.event import ( 

82 AfterAsyncClusterInstantiationEvent, 

83 AsyncAfterSlotsCacheRefreshEvent, 

84 AsyncEventListenerInterface, 

85 EventDispatcher, 

86) 

87from redis.exceptions import ( 

88 AskError, 

89 BusyLoadingError, 

90 ClusterDownError, 

91 ClusterError, 

92 ConnectionError, 

93 CrossSlotTransactionError, 

94 DataError, 

95 ExecAbortError, 

96 InvalidPipelineStack, 

97 MaxConnectionsError, 

98 MovedError, 

99 RedisClusterException, 

100 RedisError, 

101 ResponseError, 

102 SlotNotCoveredError, 

103 TimeoutError, 

104 TryAgainError, 

105 WatchError, 

106) 

107from redis.typing import AnyKeyT, EncodableT, KeyT 

108from redis.utils import ( 

109 DEFAULT_RESP_VERSION, 

110 SSL_AVAILABLE, 

111 check_protocol_version, 

112 deprecated_args, 

113 deprecated_function, 

114 safe_str, 

115 str_if_bytes, 

116 truncate_text, 

117) 

118 

119if SSL_AVAILABLE: 

120 from ssl import TLSVersion, VerifyFlags, VerifyMode 

121else: 

122 TLSVersion = None 

123 VerifyMode = None 

124 VerifyFlags = None 

125 

126logger = logging.getLogger(__name__) 

127 

128TargetNodesT = TypeVar( 

129 "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] 

130) 

131 

132 

133class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

134 """ 

135 Create a new RedisCluster client. 

136 

137 Pass one of parameters: 

138 

139 - `host` & `port` 

140 - `startup_nodes` 

141 

142 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. 

143 | Use ``await`` :meth:`close` to disconnect connections & close client. 

144 

145 Many commands support the target_nodes kwarg. It can be one of the 

146 :attr:`NODE_FLAGS`: 

147 

148 - :attr:`PRIMARIES` 

149 - :attr:`REPLICAS` 

150 - :attr:`ALL_NODES` 

151 - :attr:`RANDOM` 

152 - :attr:`DEFAULT_NODE` 

153 

154 Note: This client is not thread/process/fork safe. 

155 

156 :param host: 

157 | Can be used to point to a startup node 

158 :param port: 

159 | Port used if **host** is provided 

160 :param startup_nodes: 

161 | :class:`~.ClusterNode` instances to be used as startup nodes

162 :param require_full_coverage: 

163 | When set to ``False``: the client will not require a full coverage of 

164 the slots. However, if not all slots are covered, and at least one node 

165 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw 

166 a :class:`~.ClusterDownError` for some key-based commands. 

167 | When set to ``True``: all slots must be covered to construct the cluster 

168 client. If not all slots are covered, :class:`~.RedisClusterException` will be 

169 thrown. 

170 | See: 

171 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters 

172 :param read_from_replicas: 

173 | @deprecated - please use load_balancing_strategy instead 

174 | Enable read from replicas in READONLY mode. 

175 When set to true, read commands will be distributed between the primary and

176 its replicas in a round-robin manner.

177 The data read from replicas is eventually consistent with the data in primary nodes. 

178 :param load_balancing_strategy: 

179 | Enable read from replicas in READONLY mode and defines the load balancing 

180 strategy that will be used for cluster node selection. 

181 The data read from replicas is eventually consistent with the data in primary nodes. 

182 :param dynamic_startup_nodes: 

183 | Set the RedisCluster's startup nodes to all the discovered nodes. 

184 If true (default value), the cluster's discovered nodes will be used to 

185 determine the cluster nodes-slots mapping in the next topology refresh. 

186 It will remove the initial passed startup nodes if their endpoints aren't 

187 listed in the CLUSTER SLOTS output. 

188 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists 

189 specific IP addresses, it is best to set it to false. 

190 :param reinitialize_steps: 

191 | Specifies the number of MOVED errors that need to occur before reinitializing 

192 the whole cluster topology. If a MOVED error occurs and the cluster does not 

193 need to be reinitialized on this current error handling, only the MOVED slot 

194 will be patched with the redirected node. 

195 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. 

196 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 

197 0. 

198 :param cluster_error_retry_attempts: 

199 | @deprecated - Please configure the 'retry' object instead 

200 If the 'retry' object is set, this argument is ignored.

201 

202 Number of times to retry before raising an error when :class:`~.TimeoutError`, 

203 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` 

204 or :class:`~.ClusterDownError` are encountered 

205 :param retry: 

206 | A retry object that defines the retry strategy and the number of 

207 retries for the cluster client. 

208 In the current implementation of the cluster client (starting from redis-py version 6.0.0),

209 the retry object is not yet fully utilized; it is used only to determine

210 the number of retries for the cluster client.

211 In future releases the retry object will be used to handle the cluster client retries.

212 :param max_connections: 

213 | Maximum number of connections per node. If there are no free connections & the 

214 maximum number of connections are already created, a 

215 :class:`~.MaxConnectionsError` is raised. 

216 :param address_remap: 

217 | An optional callable which, when provided with an internal network 

218 address of a node, e.g. a `(host, port)` tuple, will return the address 

219 where the node is reachable. This can be used to map the addresses at 

220 which the nodes _think_ they are, to addresses at which a client may 

221 reach them, such as when they sit behind a proxy. 

222 

223 | The rest of the arguments will be passed to the

224 :class:`~redis.asyncio.connection.Connection` instances when created 

225 

226 :raises RedisClusterException: 

227 if any arguments are invalid or unknown, e.g.:

228 

229 - `db` != 0 or None 

230 - `path` argument for unix socket connection 

231 - none of the `host`/`port` & `startup_nodes` were provided 

232 
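A minimal usage sketch (hedged: the host and port below are assumptions; point
them at any reachable node of your cluster)::

    import asyncio

    from redis.asyncio.cluster import RedisCluster

    async def main():
        # One startup node is enough; the rest of the topology is
        # discovered on first use via CLUSTER SLOTS.
        rc = RedisCluster(host="localhost", port=16379)
        try:
            await rc.set("foo", "bar")
            print(await rc.get("foo"))
        finally:
            await rc.aclose()

    asyncio.run(main())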

233 """ 

234 

235 @classmethod 

236 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": 

237 """ 

238 Return a Redis client object configured from the given URL. 

239 

240 For example:: 

241 

242 redis://[[username]:[password]]@localhost:6379/0 

243 rediss://[[username]:[password]]@localhost:6379/0 

244 

245 Three URL schemes are supported: 

246 

247 - `redis://` creates a TCP socket connection. See more at: 

248 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

249 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

250 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

251 

252 The username, password, hostname, path and all querystring values are passed 

253 through ``urllib.parse.unquote`` in order to replace any percent-encoded values 

254 with their corresponding characters. 

255 

256 All querystring options are cast to their appropriate Python types. Boolean 

257 arguments can be specified with string values "True"/"False" or "Yes"/"No". 

258 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once 

259 parsed, the querystring arguments and keyword arguments are passed to 

260 :class:`~redis.asyncio.connection.Connection` when created. 

261 In the case of conflicting arguments, querystring arguments are used. 
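A short sketch (the URL below is an assumption; substitute one of your own
cluster nodes)::

    rc = RedisCluster.from_url("redis://localhost:16379/0")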

262 """ 

263 kwargs.update(parse_url(url)) 

264 if kwargs.pop("connection_class", None) is SSLConnection: 

265 kwargs["ssl"] = True 

266 return cls(**kwargs) 

267 

268 # Type discrimination marker for @overload self-type pattern 

269 _is_async_client: Literal[True] = True 

270 

271 __slots__ = ( 

272 "_initialize", 

273 "_lock", 

274 "retry", 

275 "command_flags", 

276 "commands_parser", 

277 "connection_kwargs", 

278 "encoder", 

279 "node_flags", 

280 "nodes_manager", 

281 "read_from_replicas", 

282 "reinitialize_counter", 

283 "reinitialize_steps", 

284 "response_callbacks", 

285 "result_callbacks", 

286 ) 

287 

288 @deprecated_args( 

289 args_to_warn=["read_from_replicas"], 

290 reason="Please configure the 'load_balancing_strategy' instead", 

291 version="5.3.0", 

292 ) 

293 @deprecated_args( 

294 args_to_warn=[ 

295 "cluster_error_retry_attempts", 

296 ], 

297 reason="Please configure the 'retry' object instead", 

298 version="6.0.0", 

299 ) 

300 @deprecated_args( 

301 args_to_warn=["lib_name", "lib_version"], 

302 reason="Use 'driver_info' parameter instead. " 

303 "lib_name and lib_version will be removed in a future version.", 

304 ) 

305 def __init__( 

306 self, 

307 host: Optional[str] = None, 

308 port: Union[str, int] = 6379, 

309 # Cluster related kwargs 

310 startup_nodes: Optional[List["ClusterNode"]] = None, 

311 require_full_coverage: bool = True, 

312 read_from_replicas: bool = False, 

313 load_balancing_strategy: Optional[LoadBalancingStrategy] = None, 

314 dynamic_startup_nodes: bool = True, 

315 reinitialize_steps: int = 5, 

316 cluster_error_retry_attempts: int = 3, 

317 max_connections: int = 2**31, 

318 retry: Optional["Retry"] = None, 

319 retry_on_error: Optional[List[Type[Exception]]] = None, 

320 # Client related kwargs 

321 db: Union[str, int] = 0, 

322 path: Optional[str] = None, 

323 credential_provider: Optional[CredentialProvider] = None, 

324 username: Optional[str] = None, 

325 password: Optional[str] = None, 

326 client_name: Optional[str] = None, 

327 lib_name: Optional[str] = None, 

328 lib_version: Optional[str] = None, 

329 driver_info: Optional["DriverInfo"] = None, 

330 # Encoding related kwargs 

331 encoding: str = "utf-8", 

332 encoding_errors: str = "strict", 

333 decode_responses: bool = False, 

334 # Connection related kwargs 

335 health_check_interval: float = 0, 

336 socket_connect_timeout: Optional[float] = None, 

337 socket_keepalive: bool = False, 

338 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, 

339 socket_timeout: Optional[float] = None, 

340 # SSL related kwargs 

341 ssl: bool = False, 

342 ssl_ca_certs: Optional[str] = None, 

343 ssl_ca_data: Optional[str] = None, 

344 ssl_cert_reqs: Union[str, VerifyMode] = "required", 

345 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, 

346 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, 

347 ssl_certfile: Optional[str] = None, 

348 ssl_check_hostname: bool = True, 

349 ssl_keyfile: Optional[str] = None, 

350 ssl_min_version: Optional[TLSVersion] = None, 

351 ssl_ciphers: Optional[str] = None, 

352 protocol: Optional[int] = 3, 

353 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, 

354 event_dispatcher: Optional[EventDispatcher] = None, 

355 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(), 

356 ) -> None: 

357 if db: 

358 raise RedisClusterException( 

359 "Argument 'db' must be 0 or None in cluster mode" 

360 ) 

361 

362 if path: 

363 raise RedisClusterException( 

364 "Unix domain socket is not supported in cluster mode" 

365 ) 

366 

367 if (not host or not port) and not startup_nodes: 

368 raise RedisClusterException( 

369 "RedisCluster requires at least one node to discover the cluster.\n" 

370 "Please provide one of the following or use RedisCluster.from_url:\n" 

371 ' - host and port: RedisCluster(host="localhost", port=6379)\n' 

372 " - startup_nodes: RedisCluster(startup_nodes=[" 

373 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])' 

374 ) 

375 

376 computed_driver_info = resolve_driver_info(driver_info, lib_name, lib_version) 

377 

378 kwargs: Dict[str, Any] = { 

379 "max_connections": max_connections, 

380 "connection_class": Connection, 

381 # Client related kwargs 

382 "credential_provider": credential_provider, 

383 "username": username, 

384 "password": password, 

385 "client_name": client_name, 

386 "driver_info": computed_driver_info, 

387 # Encoding related kwargs 

388 "encoding": encoding, 

389 "encoding_errors": encoding_errors, 

390 "decode_responses": decode_responses, 

391 # Connection related kwargs 

392 "health_check_interval": health_check_interval, 

393 "socket_connect_timeout": socket_connect_timeout, 

394 "socket_keepalive": socket_keepalive, 

395 "socket_keepalive_options": socket_keepalive_options, 

396 "socket_timeout": socket_timeout, 

397 "protocol": protocol, 

398 } 

399 

400 if ssl: 

401 # SSL related kwargs 

402 kwargs.update( 

403 { 

404 "connection_class": SSLConnection, 

405 "ssl_ca_certs": ssl_ca_certs, 

406 "ssl_ca_data": ssl_ca_data, 

407 "ssl_cert_reqs": ssl_cert_reqs, 

408 "ssl_include_verify_flags": ssl_include_verify_flags, 

409 "ssl_exclude_verify_flags": ssl_exclude_verify_flags, 

410 "ssl_certfile": ssl_certfile, 

411 "ssl_check_hostname": ssl_check_hostname, 

412 "ssl_keyfile": ssl_keyfile, 

413 "ssl_min_version": ssl_min_version, 

414 "ssl_ciphers": ssl_ciphers, 

415 } 

416 ) 

417 

418 if read_from_replicas or load_balancing_strategy: 

419 # Call our on_connect function to configure READONLY mode 

420 kwargs["redis_connect_func"] = self.on_connect 

421 

422 if retry: 

423 self.retry = retry 

424 else: 

425 self.retry = Retry( 

426 backoff=ExponentialWithJitterBackoff(base=1, cap=10), 

427 retries=cluster_error_retry_attempts, 

428 ) 

429 if retry_on_error: 

430 self.retry.update_supported_errors(retry_on_error) 

431 

432 kwargs["response_callbacks"] = _RedisCallbacks.copy() 

433 if check_protocol_version(kwargs.get("protocol", DEFAULT_RESP_VERSION), 3): 

434 kwargs["response_callbacks"].update(_RedisCallbacksRESP3) 

435 else: 

436 kwargs["response_callbacks"].update(_RedisCallbacksRESP2) 

437 self.connection_kwargs = kwargs 

438 

439 if startup_nodes: 

440 passed_nodes = [] 

441 for node in startup_nodes: 

442 passed_nodes.append( 

443 ClusterNode(node.host, node.port, **self.connection_kwargs) 

444 ) 

445 startup_nodes = passed_nodes 

446 else: 

447 startup_nodes = [] 

448 if host and port: 

449 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) 

450 

451 if event_dispatcher is None: 

452 self._event_dispatcher = EventDispatcher() 

453 else: 

454 self._event_dispatcher = event_dispatcher 

455 

456 self.startup_nodes = startup_nodes 

457 self.nodes_manager = NodesManager( 

458 startup_nodes, 

459 require_full_coverage, 

460 kwargs, 

461 dynamic_startup_nodes=dynamic_startup_nodes, 

462 address_remap=address_remap, 

463 event_dispatcher=self._event_dispatcher, 

464 ) 

465 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

466 self.read_from_replicas = read_from_replicas 

467 self.load_balancing_strategy = load_balancing_strategy 

468 self.reinitialize_steps = reinitialize_steps 

469 self.reinitialize_counter = 0 

470 

471 # For backward compatibility, mapping from existing policies to new one 

472 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { 

473 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, 

474 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, 

475 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, 

476 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, 

477 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, 

478 SLOT_ID: RequestPolicy.DEFAULT_KEYED, 

479 } 

480 

481 self._policies_callback_mapping: dict[ 

482 Union[RequestPolicy, ResponsePolicy], Callable 

483 ] = { 

484 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ 

485 self.get_random_primary_or_all_nodes(command_name) 

486 ], 

487 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot, 

488 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], 

489 RequestPolicy.ALL_SHARDS: self.get_primaries, 

490 RequestPolicy.ALL_NODES: self.get_nodes, 

491 RequestPolicy.ALL_REPLICAS: self.get_replicas, 

492 RequestPolicy.SPECIAL: self.get_special_nodes, 

493 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, 

494 ResponsePolicy.DEFAULT_KEYED: lambda res: res, 

495 } 

496 

497 self._policy_resolver = policy_resolver 

498 self.commands_parser = AsyncCommandsParser() 

499 self._aggregate_nodes = None 

500 self.node_flags = self.__class__.NODE_FLAGS.copy() 

501 self.command_flags = self.__class__.COMMAND_FLAGS.copy() 

502 self.response_callbacks = kwargs["response_callbacks"] 

503 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() 

504 self.result_callbacks["CLUSTER SLOTS"] = ( 

505 lambda cmd, res, **kwargs: parse_cluster_slots( 

506 list(res.values())[0], **kwargs 

507 ) 

508 ) 

509 

510 self._initialize = True 

511 self._lock: Optional[asyncio.Lock] = None 

512 

513 # When used as an async context manager, we need to increment and decrement 

514 # a usage counter so that we can close the connection pool when no one is 

515 # using the client. 

516 self._usage_counter = 0 

517 self._usage_lock = asyncio.Lock() 

518 

519 async def initialize(self) -> "RedisCluster": 

520 """Get all nodes from the startup nodes & create connections if not initialized."""

521 if self._initialize: 

522 if not self._lock: 

523 self._lock = asyncio.Lock() 

524 async with self._lock: 

525 if self._initialize: 

526 try: 

527 await self.nodes_manager.initialize() 

528 await self.commands_parser.initialize( 

529 self.nodes_manager.default_node 

530 ) 

531 self._initialize = False 

532 except BaseException: 

533 await self.nodes_manager.aclose() 

534 await self.nodes_manager.aclose("startup_nodes") 

535 raise 

536 return self 

537 

538 async def aclose(self) -> None: 

539 """Close all connections & client if initialized.""" 

540 if not self._initialize: 

541 if not self._lock: 

542 self._lock = asyncio.Lock() 

543 async with self._lock: 

544 if not self._initialize: 

545 self._initialize = True 

546 await self.nodes_manager.aclose() 

547 await self.nodes_manager.aclose("startup_nodes") 

548 

549 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") 

550 async def close(self) -> None: 

551 """alias for aclose() for backwards compatibility""" 

552 await self.aclose() 

553 

554 async def __aenter__(self) -> "RedisCluster": 

555 """ 

556 Async context manager entry. Increments a usage counter so that the 

557 connection pool is only closed (via aclose()) when no context is using 

558 the client. 
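A usage sketch (hedged: assumes a node is reachable at localhost:16379)::

    async with RedisCluster(host="localhost", port=16379) as rc:
        await rc.set("foo", "bar")
    # The connection pool is closed only when the last nested context exits.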

559 """ 

560 await self._increment_usage() 

561 try: 

562 # Initialize the client (i.e. establish connection, etc.) 

563 return await self.initialize() 

564 except Exception: 

565 # If initialization fails, decrement the counter to keep it in sync 

566 await self._decrement_usage() 

567 raise 

568 

569 async def _increment_usage(self) -> int: 

570 """ 

571 Helper coroutine to increment the usage counter while holding the lock. 

572 Returns the new value of the usage counter. 

573 """ 

574 async with self._usage_lock: 

575 self._usage_counter += 1 

576 return self._usage_counter 

577 

578 async def _decrement_usage(self) -> int: 

579 """ 

580 Helper coroutine to decrement the usage counter while holding the lock. 

581 Returns the new value of the usage counter. 

582 """ 

583 async with self._usage_lock: 

584 self._usage_counter -= 1 

585 return self._usage_counter 

586 

587 async def __aexit__(self, exc_type, exc_value, traceback): 

588 """ 

589 Async context manager exit. Decrements a usage counter. If this is the 

590 last exit (counter becomes zero), the client closes its connection pool. 

591 """ 

592 current_usage = await asyncio.shield(self._decrement_usage()) 

593 if current_usage == 0: 

594 # This was the last active context, so disconnect the pool. 

595 await asyncio.shield(self.aclose()) 

596 

597 def __await__(self) -> Generator[Any, None, "RedisCluster"]: 

598 return self.initialize().__await__() 

599 

600 _DEL_MESSAGE = "Unclosed RedisCluster client" 

601 

602 def __del__( 

603 self, 

604 _warn: Any = warnings.warn, 

605 _grl: Any = asyncio.get_running_loop, 

606 ) -> None: 

607 if hasattr(self, "_initialize") and not self._initialize: 

608 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) 

609 try: 

610 context = {"client": self, "message": self._DEL_MESSAGE} 

611 _grl().call_exception_handler(context) 

612 except RuntimeError: 

613 pass 

614 

615 async def on_connect(self, connection: Connection) -> None: 

616 await connection.on_connect() 

617 

618 # Sending READONLY command to server to configure connection as 

619 # readonly. Since each cluster node may change its server type due 

620 # to a failover, we should establish a READONLY connection 

621 # regardless of the server type. If this is a primary connection, 

622 # READONLY would not affect executing write commands. 

623 await connection.send_command("READONLY") 

624 if str_if_bytes(await connection.read_response()) != "OK": 

625 raise ConnectionError("READONLY command failed") 

626 

627 def get_nodes(self) -> List["ClusterNode"]: 

628 """Get all nodes of the cluster.""" 

629 return list(self.nodes_manager.nodes_cache.values()) 

630 

631 def get_primaries(self) -> List["ClusterNode"]: 

632 """Get the primary nodes of the cluster.""" 

633 return self.nodes_manager.get_nodes_by_server_type(PRIMARY) 

634 

635 def get_replicas(self) -> List["ClusterNode"]: 

636 """Get the replica nodes of the cluster.""" 

637 return self.nodes_manager.get_nodes_by_server_type(REPLICA) 

638 

639 def get_random_node(self) -> "ClusterNode": 

640 """Get a random node of the cluster.""" 

641 return random.choice(list(self.nodes_manager.nodes_cache.values())) 

642 

643 def get_default_node(self) -> "ClusterNode": 

644 """Get the default node of the client.""" 

645 return self.nodes_manager.default_node 

646 

647 def set_default_node(self, node: "ClusterNode") -> None: 

648 """ 

649 Set the default node of the client. 

650 

651 :raises DataError: if None is passed or the node does not exist in the cluster.

652 """ 

653 if not node or not self.get_node(node_name=node.name): 

654 raise DataError("The requested node does not exist in the cluster.") 

655 

656 self.nodes_manager.default_node = node 

657 

658 def get_node( 

659 self, 

660 host: Optional[str] = None, 

661 port: Optional[int] = None, 

662 node_name: Optional[str] = None, 

663 ) -> Optional["ClusterNode"]: 

664 """Get node by (host, port) or node_name.""" 

665 return self.nodes_manager.get_node(host, port, node_name) 

666 

667 def get_node_from_key( 

668 self, key: str, replica: bool = False 

669 ) -> Optional["ClusterNode"]: 

670 """ 

671 Get the cluster node corresponding to the provided key. 

672 

673 :param key: 

674 :param replica: 

675 | Indicates if a replica should be returned 

676 | 

677 None will be returned if no replica holds this key

678 

679 :raises SlotNotCoveredError: if the key is not covered by any slot. 
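For example (assuming an initialized client ``rc``; the key is illustrative)::

    primary = rc.get_node_from_key("foo")
    replica = rc.get_node_from_key("foo", replica=True)  # may be None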

680 """ 

681 slot = self.keyslot(key) 

682 slot_cache = self.nodes_manager.slots_cache.get(slot) 

683 if not slot_cache: 

684 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') 

685 

686 if replica: 

687 if len(self.nodes_manager.slots_cache[slot]) < 2: 

688 return None 

689 node_idx = 1 

690 else: 

691 node_idx = 0 

692 

693 return slot_cache[node_idx] 

694 

695 def get_random_primary_or_all_nodes(self, command_name): 

696 """ 

697 Returns a random node for read commands in READONLY mode, otherwise a random primary.

698 """ 

699 if self.read_from_replicas and command_name in READ_COMMANDS: 

700 return self.get_random_node() 

701 

702 return self.get_random_primary_node() 

703 

704 def get_random_primary_node(self) -> "ClusterNode": 

705 """ 

706 Returns a random primary node 

707 """ 

708 return random.choice(self.get_primaries()) 

709 

710 async def get_nodes_from_slot(self, command: str, *args): 

711 """ 

712 Returns a list containing the node that holds the slot of the specified keys.

713 """ 

714 # get the node that holds the key's slot 

715 return [ 

716 self.nodes_manager.get_node_from_slot( 

717 await self._determine_slot(command, *args), 

718 self.read_from_replicas and command in READ_COMMANDS, 

719 self.load_balancing_strategy if command in READ_COMMANDS else None, 

720 ) 

721 ] 

722 

723 def get_special_nodes(self) -> Optional[list["ClusterNode"]]: 

724 """ 

725 Returns a list of nodes for commands with a special policy. 

726 """ 

727 if not self._aggregate_nodes: 

728 raise RedisClusterException( 

729 "Cannot execute FT.CURSOR commands without FT.AGGREGATE" 

730 ) 

731 

732 return self._aggregate_nodes 

733 

734 def keyslot(self, key: EncodableT) -> int: 

735 """ 

736 Find the keyslot for a given key. 

737 

738 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding 
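For example, keys sharing a hash tag map to the same slot (``rc`` is assumed
to be an initialized client)::

    assert rc.keyslot("{user:1000}.following") == rc.keyslot("{user:1000}.followers")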

739 """ 

740 return key_slot(self.encoder.encode(key)) 

741 

742 def get_encoder(self) -> Encoder: 

743 """Get the encoder object of the client.""" 

744 return self.encoder 

745 

746 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: 

747 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" 

748 return self.connection_kwargs 

749 

750 def set_retry(self, retry: Retry) -> None: 

751 self.retry = retry 

752 

753 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: 

754 """Set a custom response callback.""" 

755 self.response_callbacks[command] = callback 

756 

757 async def _determine_nodes( 

758 self, 

759 command: str, 

760 *args: Any, 

761 request_policy: RequestPolicy, 

762 node_flag: Optional[str] = None, 

763 ) -> List["ClusterNode"]: 

764 # Determine which nodes the command should be executed on.

765 # Returns a list of target nodes. 

766 if not node_flag: 

767 # get the nodes group for this command if it was predefined 

768 node_flag = self.command_flags.get(command) 

769 

770 if node_flag in self._command_flags_mapping: 

771 request_policy = self._command_flags_mapping[node_flag] 

772 

773 policy_callback = self._policies_callback_mapping[request_policy] 

774 

775 if request_policy == RequestPolicy.DEFAULT_KEYED: 

776 nodes = await policy_callback(command, *args) 

777 elif request_policy == RequestPolicy.DEFAULT_KEYLESS: 

778 nodes = policy_callback(command) 

779 else: 

780 nodes = policy_callback() 

781 

782 if command.lower() == "ft.aggregate": 

783 self._aggregate_nodes = nodes 

784 

785 return nodes 

786 

787 async def _determine_slot(self, command: str, *args: Any) -> int: 

788 if self.command_flags.get(command) == SLOT_ID: 

789 # The command contains the slot ID 

790 return int(args[0]) 

791 

792 # Get the keys in the command 

793 

794 # EVAL and EVALSHA are common enough that it's wasteful to go to the 

795 # redis server to parse the keys. Besides, there is a bug in redis<7.0 

796 # where `self._get_command_keys()` fails anyway. So, we special case 

797 # EVAL/EVALSHA. 

798 # - issue: https://github.com/redis/redis/issues/9493 

799 # - fix: https://github.com/redis/redis/pull/9733 

800 if command.upper() in ("EVAL", "EVALSHA"): 

801 # command syntax: EVAL "script body" num_keys ... 

802 if len(args) < 2: 

803 raise RedisClusterException( 

804 f"Invalid args in command: {command, *args}" 

805 ) 

806 keys = args[2 : 2 + int(args[1])] 

807 # if there are 0 keys, that means the script can be run on any node 

808 # so we can just return a random slot 

809 if not keys: 

810 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) 

811 else: 

812 keys = await self.commands_parser.get_keys(command, *args) 

813 if not keys: 

814 # FCALL can call a function with 0 keys, that means the function 

815 # can be run on any node so we can just return a random slot 

816 if command.upper() in ("FCALL", "FCALL_RO"): 

817 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) 

818 raise RedisClusterException( 

819 "No way to dispatch this command to Redis Cluster. " 

820 "Missing key.\nYou can execute the command by specifying " 

821 f"target nodes.\nCommand: {args}" 

822 ) 

823 

824 # single key command 

825 if len(keys) == 1: 

826 return self.keyslot(keys[0]) 

827 

828 # multi-key command; we need to make sure all keys are mapped to 

829 # the same slot 

830 slots = {self.keyslot(key) for key in keys} 

831 if len(slots) != 1: 

832 raise RedisClusterException( 

833 f"{command} - all keys must map to the same key slot" 

834 ) 

835 

836 return slots.pop() 

837 

838 def _is_node_flag(self, target_nodes: Any) -> bool: 

839 return isinstance(target_nodes, str) and target_nodes in self.node_flags 

840 

841 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: 

842 if isinstance(target_nodes, list): 

843 nodes = target_nodes 

844 elif isinstance(target_nodes, ClusterNode): 

845 # Supports passing a single ClusterNode as a variable 

846 nodes = [target_nodes] 

847 elif isinstance(target_nodes, dict): 

848 # Supports dictionaries of the format {node_name: node}. 

849 # It enables executing commands on multiple nodes, for example:

850 # rc.cluster_save_config(rc.get_primaries()) 

851 nodes = list(target_nodes.values()) 

852 else: 

853 raise TypeError( 

854 "target_nodes type can be one of the following: " 

855 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," 

856 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. " 

857 f"The passed type is {type(target_nodes)}" 

858 ) 

859 return nodes 

860 

861 async def _record_error_metric( 

862 self, 

863 error: Exception, 

864 connection: Union[Connection, "ClusterNode"], 

865 is_internal: bool = True, 

866 retry_attempts: Optional[int] = None, 

867 ): 

868 """ 

869 Records error count metric directly. 

870 Accepts either a Connection or ClusterNode object. 

871 """ 

872 await record_error_count( 

873 server_address=connection.host, 

874 server_port=connection.port, 

875 network_peer_address=connection.host, 

876 network_peer_port=connection.port, 

877 error_type=error, 

878 retry_attempts=retry_attempts if retry_attempts is not None else 0, 

879 is_internal=is_internal, 

880 ) 

881 

882 async def _record_command_metric( 

883 self, 

884 command_name: str, 

885 duration_seconds: float, 

886 connection: Union[Connection, "ClusterNode"], 

887 error: Optional[Exception] = None, 

888 ): 

889 """ 

890 Records operation duration metric directly. 

891 Accepts either a Connection or ClusterNode object. 

892 """ 

893 # Connection has db attribute, ClusterNode has connection_kwargs 

894 if hasattr(connection, "db"): 

895 db = connection.db 

896 else: 

897 db = connection.connection_kwargs.get("db", 0) 

898 await record_operation_duration( 

899 command_name=command_name, 

900 duration_seconds=duration_seconds, 

901 server_address=connection.host, 

902 server_port=connection.port, 

903 db_namespace=str(db) if db is not None else None, 

904 error=error, 

905 ) 

906 

907 async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: 

908 """ 

909 Execute a raw command on the appropriate cluster node or target_nodes. 

910 

911 It will retry the command as specified by the retries property of 

912 the :attr:`retry` & then raise an exception. 

913 

914 :param args: 

915 | Raw command args 

916 :param kwargs: 

917 

918 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` 

919 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] 

920 - Rest of the kwargs are passed to the Redis connection 

921 

922 :raises RedisClusterException: if target_nodes is not provided & the command 

923 can't be mapped to a slot 
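A short sketch (assuming an initialized client ``rc``; the command choices are
illustrative only)::

    # Keyed command: routed to the node owning the key's slot.
    await rc.execute_command("SET", "foo", "bar")

    # Fan out to every primary; typically returns a dict keyed by node name.
    await rc.execute_command(
        "CONFIG", "GET", "maxmemory", target_nodes=RedisCluster.PRIMARIES
    )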

924 """ 

925 command = args[0] 

926 target_nodes = [] 

927 target_nodes_specified = False 

928 retry_attempts = self.retry.get_retries() 

929 

930 passed_targets = kwargs.pop("target_nodes", None) 

931 if passed_targets and not self._is_node_flag(passed_targets): 

932 target_nodes = self._parse_target_nodes(passed_targets) 

933 target_nodes_specified = True 

934 retry_attempts = 0 

935 

936 command_policies = await self._policy_resolver.resolve(args[0].lower()) 

937 

938 if not command_policies and not target_nodes_specified: 

939 command_flag = self.command_flags.get(command) 

940 if not command_flag: 

941 # Fallback to default policy 

942 if not self.get_default_node(): 

943 slot = None 

944 else: 

945 slot = await self._determine_slot(*args) 

946 if slot is None: 

947 command_policies = CommandPolicies() 

948 else: 

949 command_policies = CommandPolicies( 

950 request_policy=RequestPolicy.DEFAULT_KEYED, 

951 response_policy=ResponsePolicy.DEFAULT_KEYED, 

952 ) 

953 else: 

954 if command_flag in self._command_flags_mapping: 

955 command_policies = CommandPolicies( 

956 request_policy=self._command_flags_mapping[command_flag] 

957 ) 

958 else: 

959 command_policies = CommandPolicies() 

960 elif not command_policies and target_nodes_specified: 

961 command_policies = CommandPolicies() 

962 

963 # Add one for the first execution 

964 execute_attempts = 1 + retry_attempts 

965 failure_count = 0 

966 

967 # Start timing for observability 

968 start_time = time.monotonic() 

969 

970 for _ in range(execute_attempts): 

971 if self._initialize: 

972 await self.initialize() 

973 if ( 

974 len(target_nodes) == 1 

975 and target_nodes[0] == self.get_default_node() 

976 ): 

977 # Replace the default cluster node 

978 self.replace_default_node() 

979 try: 

980 if not target_nodes_specified: 

981 # Determine the nodes to execute the command on 

982 target_nodes = await self._determine_nodes( 

983 *args, 

984 request_policy=command_policies.request_policy, 

985 node_flag=passed_targets, 

986 ) 

987 if not target_nodes: 

988 raise RedisClusterException( 

989 f"No targets were found to execute {args} command on" 

990 ) 

991 

992 if len(target_nodes) == 1: 

993 # Return the processed result 

994 ret = await self._execute_command(target_nodes[0], *args, **kwargs) 

995 if command in self.result_callbacks: 

996 ret = self.result_callbacks[command]( 

997 command, {target_nodes[0].name: ret}, **kwargs 

998 ) 

999 return self._policies_callback_mapping[ 

1000 command_policies.response_policy 

1001 ](ret) 

1002 else: 

1003 keys = [node.name for node in target_nodes] 

1004 values = await asyncio.gather( 

1005 *( 

1006 asyncio.create_task( 

1007 self._execute_command(node, *args, **kwargs) 

1008 ) 

1009 for node in target_nodes 

1010 ) 

1011 ) 

1012 if command in self.result_callbacks: 

1013 return self.result_callbacks[command]( 

1014 command, dict(zip(keys, values)), **kwargs 

1015 ) 

1016 return self._policies_callback_mapping[ 

1017 command_policies.response_policy 

1018 ](dict(zip(keys, values))) 

1019 except Exception as e: 

1020 if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: 

1021 # The nodes and slots cache should be reinitialized.

1022 # Try again with the new cluster setup. 

1023 retry_attempts -= 1 

1024 failure_count += 1 

1025 

1026 if hasattr(e, "connection"): 

1027 await self._record_command_metric( 

1028 command_name=command, 

1029 duration_seconds=time.monotonic() - start_time, 

1030 connection=e.connection, 

1031 error=e, 

1032 ) 

1033 await self._record_error_metric( 

1034 error=e, 

1035 connection=e.connection, 

1036 retry_attempts=failure_count, 

1037 ) 

1038 continue 

1039 else: 

1040 # raise the exception 

1041 if hasattr(e, "connection"): 

1042 await self._record_error_metric( 

1043 error=e, 

1044 connection=e.connection, 

1045 retry_attempts=failure_count, 

1046 is_internal=False, 

1047 ) 

1048 raise e 

1049 

1050 async def _execute_command( 

1051 self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any 

1052 ) -> Any: 

1053 asking = moved = False 

1054 redirect_addr = None 

1055 ttl = self.RedisClusterRequestTTL 

1056 command = args[0] 

1057 start_time = time.monotonic() 

1058 

1059 while ttl > 0: 

1060 ttl -= 1 

1061 try: 

1062 if asking: 

1063 target_node = self.get_node(node_name=redirect_addr) 

1064 await target_node.execute_command("ASKING") 

1065 asking = False 

1066 elif moved: 

1067 # MOVED occurred and the slots cache was updated, 

1068 # refresh the target node 

1069 slot = await self._determine_slot(*args) 

1070 target_node = self.nodes_manager.get_node_from_slot( 

1071 slot, 

1072 self.read_from_replicas and args[0] in READ_COMMANDS, 

1073 self.load_balancing_strategy 

1074 if args[0] in READ_COMMANDS 

1075 else None, 

1076 ) 

1077 moved = False 

1078 

1079 response = await target_node.execute_command(*args, **kwargs) 

1080 await self._record_command_metric( 

1081 command_name=command, 

1082 duration_seconds=time.monotonic() - start_time, 

1083 connection=target_node, 

1084 ) 

1085 return response 

1086 except BusyLoadingError as e: 

1087 e.connection = target_node 

1088 await self._record_command_metric( 

1089 command_name=command, 

1090 duration_seconds=time.monotonic() - start_time, 

1091 connection=target_node, 

1092 error=e, 

1093 ) 

1094 raise 

1095 except MaxConnectionsError as e: 

1096 # MaxConnectionsError indicates client-side resource exhaustion 

1097 # (too many connections in the pool), not a node failure. 

1098 # Don't treat this as a node failure - just re-raise the error 

1099 # without reinitializing the cluster. 

1100 e.connection = target_node 

1101 await self._record_command_metric( 

1102 command_name=command, 

1103 duration_seconds=time.monotonic() - start_time, 

1104 connection=target_node, 

1105 error=e, 

1106 ) 

1107 raise 

1108 except (ConnectionError, TimeoutError) as e: 

1109 # Connection retries are being handled in the node's 

1110 # Retry object. 

1111 # Mark active connections for reconnect and disconnect free ones 

1112 # This handles connection state (like READONLY) that may be stale 

1113 target_node.update_active_connections_for_reconnect() 

1114 await target_node.disconnect_free_connections() 

1115 

1116 # Move the failed node to the end of the cached nodes list 

1117 # so it's tried last during reinitialization 

1118 self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name) 

1119 

1120 # Signal that reinitialization is needed 

1121 # The retry loop will handle initialize() AND replace_default_node() 

1122 self._initialize = True 

1123 e.connection = target_node 

1124 await self._record_command_metric( 

1125 command_name=command, 

1126 duration_seconds=time.monotonic() - start_time, 

1127 connection=target_node, 

1128 error=e, 

1129 ) 

1130 raise 

1131 except (ClusterDownError, SlotNotCoveredError) as e: 

1132 # ClusterDownError can occur during a failover; to self-heal,

1133 # we will try to reinitialize the cluster layout

1134 # and retry executing the command 

1135 

1136 # SlotNotCoveredError can occur when the cluster is not fully 

1137 # initialized, or it can be a temporary issue.

1138 # We will try to reinitialize the cluster topology 

1139 # and retry executing the command 

1140 

1141 await self.aclose() 

1142 await asyncio.sleep(0.25) 

1143 e.connection = target_node 

1144 await self._record_command_metric( 

1145 command_name=command, 

1146 duration_seconds=time.monotonic() - start_time, 

1147 connection=target_node, 

1148 error=e, 

1149 ) 

1150 raise 

1151 except MovedError as e: 

1152 # First, we will try to patch the slots/nodes cache with the 

1153 # redirected node output and try again. If MovedError exceeds 

1154 # 'reinitialize_steps' number of times, we will force 

1155 # reinitializing the tables, and then try again. 

1156 # The 'reinitialize_counter' will increase faster when

1157 # the same client object is shared between multiple threads. To 

1158 # reduce the frequency you can set this variable in the 

1159 # RedisCluster constructor. 

1160 self.reinitialize_counter += 1 

1161 if ( 

1162 self.reinitialize_steps 

1163 and self.reinitialize_counter % self.reinitialize_steps == 0 

1164 ): 

1165 await self.aclose() 

1166 # Reset the counter 

1167 self.reinitialize_counter = 0 

1168 else: 

1169 await self.nodes_manager.move_slot(e) 

1170 moved = True 

1171 await self._record_command_metric( 

1172 command_name=command, 

1173 duration_seconds=time.monotonic() - start_time, 

1174 connection=target_node, 

1175 error=e, 

1176 ) 

1177 await self._record_error_metric( 

1178 error=e, 

1179 connection=target_node, 

1180 ) 

1181 except AskError as e: 

1182 redirect_addr = get_node_name(host=e.host, port=e.port) 

1183 asking = True 

1184 await self._record_command_metric( 

1185 command_name=command, 

1186 duration_seconds=time.monotonic() - start_time, 

1187 connection=target_node, 

1188 error=e, 

1189 ) 

1190 await self._record_error_metric( 

1191 error=e, 

1192 connection=target_node, 

1193 ) 

1194 except TryAgainError as e: 

1195 if ttl < self.RedisClusterRequestTTL / 2: 

1196 await asyncio.sleep(0.05) 

1197 await self._record_command_metric( 

1198 command_name=command, 

1199 duration_seconds=time.monotonic() - start_time, 

1200 connection=target_node, 

1201 error=e, 

1202 ) 

1203 await self._record_error_metric( 

1204 error=e, 

1205 connection=target_node, 

1206 ) 

1207 except ResponseError as e: 

1208 e.connection = target_node 

1209 await self._record_command_metric( 

1210 command_name=command, 

1211 duration_seconds=time.monotonic() - start_time, 

1212 connection=target_node, 

1213 error=e, 

1214 ) 

1215 raise 

1216 except Exception as e: 

1217 e.connection = target_node 

1218 await self._record_command_metric( 

1219 command_name=command, 

1220 duration_seconds=time.monotonic() - start_time, 

1221 connection=target_node, 

1222 error=e, 

1223 ) 

1224 raise 

1225 

1226 e = ClusterError("TTL exhausted.") 

1227 e.connection = target_node 

1228 await self._record_command_metric( 

1229 command_name=command, 

1230 duration_seconds=time.monotonic() - start_time, 

1231 connection=target_node, 

1232 error=e, 

1233 ) 

1234 raise e 

1235 

1236 def pipeline( 

1237 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None 

1238 ) -> "ClusterPipeline": 

1239 """ 

1240 Create & return a new :class:`~.ClusterPipeline` object. 

1241 

1242 The cluster implementation of pipeline does not support shard_hint.

1243 

1244 :raises RedisClusterException: if shard_hint is a truthy value
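A usage sketch (assuming an initialized client ``rc``; key and value are
illustrative)::

    async with rc.pipeline() as pipe:
        pipe.set("foo", "bar")
        pipe.get("foo")
        results = await pipe.execute()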

1245 """ 

1246 if shard_hint: 

1247 raise RedisClusterException("shard_hint is deprecated in cluster mode") 

1248 

1249 return ClusterPipeline(self, transaction) 

1250 

1251 def pubsub( 

1252 self, 

1253 node: Optional["ClusterNode"] = None, 

1254 host: Optional[str] = None, 

1255 port: Optional[int] = None, 

1256 **kwargs: Any, 

1257 ) -> "ClusterPubSub": 

1258 """ 

1259 Create and return a ClusterPubSub instance. 

1260 

1261 Allows passing a ClusterNode, or host&port, to get a pubsub instance 

1262 connected to the specified node 

1263 

1264 :param node: ClusterNode to connect to 

1265 :param host: Host of the node to connect to 

1266 :param port: Port of the node to connect to 

1267 :param kwargs: Additional keyword arguments 

1268 :return: ClusterPubSub instance 
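A usage sketch (assuming an initialized client ``rc``; the channel name is
illustrative)::

    pubsub = rc.pubsub()
    await pubsub.subscribe("notifications")
    message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)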

1269 """ 

1270 return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) 

1271 

1272 def keyspace_notifications( 

1273 self, 

1274 key_prefix: Union[str, bytes, None] = None, 

1275 ignore_subscribe_messages: bool = True, 

1276 ) -> "AsyncClusterKeyspaceNotifications": 

1277 """ 

1278 Return an 

1279 :class:`~redis.asyncio.keyspace_notifications.AsyncClusterKeyspaceNotifications` 

1280 object for subscribing to keyspace and keyevent notifications across 

1281 all primary nodes in the cluster. 

1282 

1283 Note: Keyspace notifications must be enabled on all Redis cluster nodes 

1284 via the ``notify-keyspace-events`` configuration option. 

1285 

1286 Args: 

1287 key_prefix: Optional prefix to filter and strip from keys in 

1288 notifications. 

1289 ignore_subscribe_messages: If True, subscribe/unsubscribe 

1290 confirmations are not returned by 

1291 get_message/listen. 

1292 """ 

1293 from redis.asyncio.keyspace_notifications import ( 

1294 AsyncClusterKeyspaceNotifications, 

1295 ) 

1296 

1297 return AsyncClusterKeyspaceNotifications( 

1298 self, 

1299 key_prefix=key_prefix, 

1300 ignore_subscribe_messages=ignore_subscribe_messages, 

1301 ) 

1302 

1303 def lock( 

1304 self, 

1305 name: KeyT, 

1306 timeout: Optional[float] = None, 

1307 sleep: float = 0.1, 

1308 blocking: bool = True, 

1309 blocking_timeout: Optional[float] = None, 

1310 lock_class: Optional[Type[Lock]] = None, 

1311 thread_local: bool = True, 

1312 raise_on_release_error: bool = True, 

1313 ) -> Lock: 

1314 """ 

1315 Return a new Lock object using key ``name`` that mimics 

1316 the behavior of threading.Lock. 

1317 

1318 If specified, ``timeout`` indicates a maximum life for the lock. 

1319 By default, it will remain locked until release() is called. 

1320 

1321 ``sleep`` indicates the amount of time to sleep per loop iteration 

1322 when the lock is in blocking mode and another client is currently 

1323 holding the lock. 

1324 

1325 ``blocking`` indicates whether calling ``acquire`` should block until 

1326 the lock has been acquired or to fail immediately, causing ``acquire`` 

1327 to return False and the lock not being acquired. Defaults to True. 

1328 Note this value can be overridden by passing a ``blocking`` 

1329 argument to ``acquire``. 

1330 

1331 ``blocking_timeout`` indicates the maximum amount of time in seconds to 

1332 spend trying to acquire the lock. A value of ``None`` indicates 

1333 continue trying forever. ``blocking_timeout`` can be specified as a 

1334 float or integer, both representing the number of seconds to wait. 

1335 

1336 ``lock_class`` forces the specified lock implementation. Note that as 

1337 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is 

1338 a Lua-based lock). So, it's unlikely you'll need this parameter, unless 

1339 you have created your own custom lock class. 

1340 

1341 ``thread_local`` indicates whether the lock token is placed in 

1342 thread-local storage. By default, the token is placed in thread local 

1343 storage so that a thread only sees its token, not a token set by 

1344 another thread. Consider the following timeline: 

1345 

1346 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. 

1347 thread-1 sets the token to "abc" 

1348 time: 1, thread-2 blocks trying to acquire `my-lock` using the 

1349 Lock instance. 

1350 time: 5, thread-1 has not yet completed. redis expires the lock 

1351 key. 

1352 time: 5, thread-2 acquired `my-lock` now that it's available. 

1353 thread-2 sets the token to "xyz" 

1354 time: 6, thread-1 finishes its work and calls release(). if the 

1355 token is *not* stored in thread local storage, then 

1356 thread-1 would see the token value as "xyz" and would be 

1357 able to successfully release thread-2's lock.

1358 

1359 ``raise_on_release_error`` indicates whether to raise an exception when 

1360 the lock is no longer owned when exiting the context manager. By default, 

1361 this is True, meaning an exception will be raised. If False, the warning 

1362 will be logged and the exception will be suppressed. 

1363 

1364 In some use cases it's necessary to disable thread local storage. For 

1365 example, if you have code where one thread acquires a lock and passes 

1366 that lock instance to a worker thread to release later. If thread 

1367 local storage isn't disabled in this case, the worker thread won't see 

1368 the token set by the thread that acquired the lock. Our assumption 

1369 is that these cases aren't common and as such default to using 

1370 thread local storage.""" 

1371 if lock_class is None: 

1372 lock_class = Lock 

1373 return lock_class( 

1374 self, 

1375 name, 

1376 timeout=timeout, 

1377 sleep=sleep, 

1378 blocking=blocking, 

1379 blocking_timeout=blocking_timeout, 

1380 thread_local=thread_local, 

1381 raise_on_release_error=raise_on_release_error, 

1382 ) 

1383 

1384 async def transaction( 

1385 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs 

1386 ): 

1387 """ 

1388 Convenience method for executing the callable `func` as a transaction 

1389 while watching all keys specified in `watches`. The 'func' callable 

1390 should expect a single argument which is a Pipeline object. 
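A sketch of the intended calling pattern (assuming an initialized client
``rc``; the key and callable are illustrative only)::

    async def double_counter(pipe):
        # Reads run while WATCHing "counter"; writes are queued after multi().
        current = int(await pipe.get("counter") or 0)
        pipe.multi()
        pipe.set("counter", current * 2)

    await rc.transaction(double_counter, "counter")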

1391 """ 

1392 shard_hint = kwargs.pop("shard_hint", None) 

1393 value_from_callable = kwargs.pop("value_from_callable", False) 

1394 watch_delay = kwargs.pop("watch_delay", None) 

1395 async with self.pipeline(True, shard_hint) as pipe: 

1396 while True: 

1397 try: 

1398 if watches: 

1399 await pipe.watch(*watches) 

1400 func_value = await func(pipe) 

1401 exec_value = await pipe.execute() 

1402 return func_value if value_from_callable else exec_value 

1403 except WatchError: 

1404 if watch_delay is not None and watch_delay > 0: 

1405 await asyncio.sleep(watch_delay)

1406 continue 

1407 

1408 

1409class ClusterNode: 

1410 """ 

1411 Create a new ClusterNode. 

1412 

1413 Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection` 

1414 objects for the (host, port). 
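For example, startup nodes can be constructed directly (the ports below are
assumptions)::

    from redis.asyncio.cluster import ClusterNode, RedisCluster

    startup_nodes = [
        ClusterNode("localhost", 16379),
        ClusterNode("localhost", 16380),
    ]
    rc = RedisCluster(startup_nodes=startup_nodes)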

1415 """ 

1416 

1417 __slots__ = ( 

1418 "_connections", 

1419 "_free", 

1420 "_lock", 

1421 "_event_dispatcher", 

1422 "connection_class", 

1423 "connection_kwargs", 

1424 "host", 

1425 "max_connections", 

1426 "name", 

1427 "port", 

1428 "response_callbacks", 

1429 "server_type", 

1430 ) 

1431 

1432 def __init__( 

1433 self, 

1434 host: str, 

1435 port: Union[str, int], 

1436 server_type: Optional[str] = None, 

1437 *, 

1438 max_connections: int = 2**31, 

1439 connection_class: Type[Connection] = Connection, 

1440 **connection_kwargs: Any, 

1441 ) -> None: 

1442 if host == "localhost": 

1443 host = socket.gethostbyname(host) 

1444 

1445 connection_kwargs["host"] = host 

1446 connection_kwargs["port"] = port 

1447 self.host = host 

1448 self.port = port 

1449 self.name = get_node_name(host, port) 

1450 self.server_type = server_type 

1451 

1452 self.max_connections = max_connections 

1453 self.connection_class = connection_class 

1454 self.connection_kwargs = connection_kwargs 

1455 self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) 

1456 

1457 self._connections: List[Connection] = [] 

1458 self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) 

1459 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) 

1460 if self._event_dispatcher is None: 

1461 self._event_dispatcher = EventDispatcher() 

1462 

1463 def __repr__(self) -> str: 

1464 return ( 

1465 f"[host={self.host}, port={self.port}, " 

1466 f"name={self.name}, server_type={self.server_type}]" 

1467 ) 

1468 

1469 def __eq__(self, obj: Any) -> bool: 

1470 return isinstance(obj, ClusterNode) and obj.name == self.name 

1471 

1472 def __hash__(self) -> int: 

1473 return hash(self.name) 

1474 

1475 _DEL_MESSAGE = "Unclosed ClusterNode object" 

1476 

1477 def __del__( 

1478 self, 

1479 _warn: Any = warnings.warn, 

1480 _grl: Any = asyncio.get_running_loop, 

1481 ) -> None: 

1482 for connection in self._connections: 

1483 if connection.is_connected: 

1484 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) 

1485 

1486 try: 

1487 context = {"client": self, "message": self._DEL_MESSAGE} 

1488 _grl().call_exception_handler(context) 

1489 except RuntimeError: 

1490 pass 

1491 break 

1492 

1493 async def disconnect(self) -> None: 

1494 ret = await asyncio.gather( 

1495 *( 

1496 asyncio.create_task(connection.disconnect()) 

1497 for connection in self._connections 

1498 ), 

1499 return_exceptions=True, 

1500 ) 

1501 exc = next((res for res in ret if isinstance(res, Exception)), None) 

1502 if exc: 

1503 raise exc 

1504 

1505 def acquire_connection(self) -> Connection: 

1506 try: 

1507 return self._free.popleft() 

1508 except IndexError: 

1509 if len(self._connections) < self.max_connections: 

1510 # We are configuring the connection pool not to retry 

1511 # connections on lower level clients to avoid retrying 

1512 # connections to nodes that are not reachable 

1513 # and to avoid blocking the connection pool. 

1514 # The only error that will have some handling in the lower 

1515 # level clients is ConnectionError which will trigger disconnection 

1516 # of the socket. 

1517 # The retries will be handled on cluster client level 

1518 # where we will have proper handling of the cluster topology 

1519 retry = Retry( 

1520 backoff=NoBackoff(), 

1521 retries=0, 

1522 supported_errors=(ConnectionError,), 

1523 ) 

1524 connection_kwargs = self.connection_kwargs.copy() 

1525 connection_kwargs["retry"] = retry 

1526 connection = self.connection_class(**connection_kwargs) 

1527 self._connections.append(connection) 

1528 return connection 

1529 

1530 raise MaxConnectionsError() 

1531 

1532 async def disconnect_if_needed(self, connection: Connection) -> None: 

1533 """ 

1534 Disconnect a connection if it's marked for reconnect. 

1535 This implements lazy disconnection to avoid race conditions. 

1536 The connection will auto-reconnect on next use. 

1537 """ 

1538 if connection.should_reconnect(): 

1539 await connection.disconnect() 

1540 

1541 def release(self, connection: Connection) -> None: 

1542 """ 

1543 Release connection back to free queue. 

1544 If the connection is marked for reconnect, it will be disconnected 

1545 lazily when next acquired via disconnect_if_needed(). 

1546 """ 

1547 self._free.append(connection) 

1548 

1549 def get_encoder(self) -> Encoder: 

1550 """Return an :class:`Encoder` derived from this node's connection kwargs.""" 

1551 kwargs = self.connection_kwargs 

1552 encoder_class = kwargs.get("encoder_class", Encoder) 

1553 return encoder_class( 

1554 encoding=kwargs.get("encoding", "utf-8"), 

1555 encoding_errors=kwargs.get("encoding_errors", "strict"), 

1556 decode_responses=kwargs.get("decode_responses", False), 

1557 ) 

1558 

1559 def update_active_connections_for_reconnect(self) -> None: 

1560 """ 

1561 Mark all in-use (active) connections for reconnect. 

1562 In-use connections are those in _connections but not currently in _free. 

1563 They will be disconnected when released back to the pool. 

1564 """ 

1565 free_set = set(self._free) 

1566 for connection in self._connections: 

1567 if connection not in free_set: 

1568 connection.mark_for_reconnect() 

1569 

1570 async def disconnect_free_connections(self) -> None: 

1571 """ 

1572 Disconnect all free/idle connections in the pool. 

1573 This is useful after topology changes (e.g., failover) to clear 

1574 stale connection state like READONLY mode. 

1575 The connections remain in the pool and will reconnect on next use. 

1576 """ 

1577 if self._free: 

1578 # Take a snapshot to avoid issues if _free changes during await 

1579 await asyncio.gather( 

1580 *(connection.disconnect() for connection in tuple(self._free)), 

1581 return_exceptions=True, 

1582 ) 

1583 

1584 async def parse_response( 

1585 self, connection: Connection, command: str, **kwargs: Any 

1586 ) -> Any: 

1587 try: 

1588 if NEVER_DECODE in kwargs: 

1589 response = await connection.read_response(disable_decoding=True) 

1590 kwargs.pop(NEVER_DECODE) 

1591 else: 

1592 response = await connection.read_response() 

1593 except ResponseError: 

1594 if EMPTY_RESPONSE in kwargs: 

1595 return kwargs[EMPTY_RESPONSE] 

1596 raise 

1597 

1598 if EMPTY_RESPONSE in kwargs: 

1599 kwargs.pop(EMPTY_RESPONSE) 

1600 

1601 # Remove the "keys" entry; it is only needed for caching. 

1602 kwargs.pop("keys", None) 

1603 

1604 # Return response 

1605 if command in self.response_callbacks: 

1606 return self.response_callbacks[command](response, **kwargs) 

1607 

1608 return response 

1609 

1610 async def execute_command(self, *args: Any, **kwargs: Any) -> Any: 

1611 # Acquire connection 

1612 connection = self.acquire_connection() 

1613 # Handle lazy disconnect for connections marked for reconnect 

1614 await self.disconnect_if_needed(connection) 

1615 

1616 # Execute command 

1617 await connection.send_packed_command(connection.pack_command(*args), False) 

1618 

1619 # Read response 

1620 try: 

1621 return await self.parse_response(connection, args[0], **kwargs) 

1622 finally: 

1623 await self.disconnect_if_needed(connection) 

1624 # Release connection 

1625 self._free.append(connection) 

1626 

1627 async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: 

1628 # Acquire connection 

1629 connection = self.acquire_connection() 

1630 # Handle lazy disconnect for connections marked for reconnect 

1631 await self.disconnect_if_needed(connection) 

1632 

1633 # Execute command 

1634 await connection.send_packed_command( 

1635 connection.pack_commands(cmd.args for cmd in commands), False 

1636 ) 

1637 

1638 # Read responses 

1639 ret = False 

1640 for cmd in commands: 

1641 try: 

1642 cmd.result = await self.parse_response( 

1643 connection, cmd.args[0], **cmd.kwargs 

1644 ) 

1645 except Exception as e: 

1646 cmd.result = e 

1647 ret = True 

1648 

1649 # Release connection 

1650 await self.disconnect_if_needed(connection) 

1651 self._free.append(connection) 

1652 

1653 return ret 

1654 

1655 async def re_auth_callback(self, token: TokenInterface): 

1656 tmp_queue = collections.deque() 

1657 while self._free: 

1658 conn = self._free.popleft() 

1659 await conn.retry.call_with_retry( 

1660 lambda: conn.send_command( 

1661 "AUTH", token.try_get("oid"), token.get_value() 

1662 ), 

1663 lambda error: self._mock(error), 

1664 ) 

1665 await conn.retry.call_with_retry( 

1666 lambda: conn.read_response(), lambda error: self._mock(error) 

1667 ) 

1668 tmp_queue.append(conn) 

1669 

1670 while tmp_queue: 

1671 conn = tmp_queue.popleft() 

1672 self._free.append(conn) 

1673 

1674 async def _mock(self, error: RedisError): 

1675 """ 

1676 Dummy function; needs to be passed as the error callback to the retry object. 

1677 :param error: 

1678 :return: 

1679 """ 

1680 pass 

1681 

1682 

1683class NodesManager: 

1684 __slots__ = ( 

1685 "_dynamic_startup_nodes", 

1686 "_event_dispatcher", 

1687 "_background_tasks", 

1688 "connection_kwargs", 

1689 "default_node", 

1690 "nodes_cache", 

1691 "_epoch", 

1692 "read_load_balancer", 

1693 "_initialize_lock", 

1694 "require_full_coverage", 

1695 "slots_cache", 

1696 "startup_nodes", 

1697 "address_remap", 

1698 ) 

1699 

1700 def __init__( 

1701 self, 

1702 startup_nodes: List["ClusterNode"], 

1703 require_full_coverage: bool, 

1704 connection_kwargs: Dict[str, Any], 

1705 dynamic_startup_nodes: bool = True, 

1706 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, 

1707 event_dispatcher: Optional[EventDispatcher] = None, 

1708 ) -> None: 

1709 self.startup_nodes = {node.name: node for node in startup_nodes} 

1710 self.require_full_coverage = require_full_coverage 

1711 self.connection_kwargs = connection_kwargs 

1712 self.address_remap = address_remap 

1713 

1714 self.default_node: "ClusterNode" = None 

1715 self.nodes_cache: Dict[str, "ClusterNode"] = {} 

1716 self.slots_cache: Dict[int, List["ClusterNode"]] = {} 

1717 self._epoch: int = 0 

1718 self.read_load_balancer = LoadBalancer() 

1719 self._initialize_lock: asyncio.Lock = asyncio.Lock() 

1720 

1721 self._background_tasks: Set[asyncio.Task] = set() 

1722 self._dynamic_startup_nodes: bool = dynamic_startup_nodes 

1723 if event_dispatcher is None: 

1724 self._event_dispatcher = EventDispatcher() 

1725 else: 

1726 self._event_dispatcher = event_dispatcher 

1727 

1728 def get_node( 

1729 self, 

1730 host: Optional[str] = None, 

1731 port: Optional[int] = None, 

1732 node_name: Optional[str] = None, 

1733 ) -> Optional["ClusterNode"]: 

1734 if host and port: 

1735 # the user passed host and port 

1736 if host == "localhost": 

1737 host = socket.gethostbyname(host) 

1738 return self.nodes_cache.get(get_node_name(host=host, port=port)) 

1739 elif node_name: 

1740 return self.nodes_cache.get(node_name) 

1741 else: 

1742 raise DataError( 

1743 "get_node requires one of the following: 1. node name 2. host and port" 

1744 ) 

1745 
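# Illustrative sketch (host, port and name values are placeholders): looking up
# a cached node either by address or by its "host:port" name, where ``nm`` is a
# NodesManager whose caches have already been initialized.
#
#     node = nm.get_node(host="127.0.0.1", port=7000)
#     same_node = nm.get_node(node_name="127.0.0.1:7000")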

1746 def set_nodes( 

1747 self, 

1748 old: Dict[str, "ClusterNode"], 

1749 new: Dict[str, "ClusterNode"], 

1750 remove_old: bool = False, 

1751 ) -> None: 

1752 if remove_old: 

1753 for name in list(old.keys()): 

1754 if name not in new: 

1755 # The node is removed from the cache before disconnect starts, 

1756 # so it won't be found in lookups while the disconnect is in progress. 

1757 # Mark active connections for reconnect so they are disconnected once 

1758 # their current command completes, and disconnect free connections 

1759 # immediately. Since the node is already out of the cache, it won't be 

1760 # reused, so it is safe not to await the disconnects. 

1761 removed_node = old.pop(name) 

1762 removed_node.update_active_connections_for_reconnect() 

1763 task = asyncio.create_task( 

1764 removed_node.disconnect_free_connections() 

1765 ) 

1766 self._background_tasks.add(task) 

1767 task.add_done_callback(self._background_tasks.discard) 

1768 

1769 for name, node in new.items(): 

1770 if name in old: 

1771 # Preserve the existing node but mark connections for reconnect. 

1772 # This method is sync so we can't call disconnect_free_connections() 

1773 # which is async. Instead, we mark free connections for reconnect 

1774 # and they will be lazily disconnected when acquired via 

1775 # disconnect_if_needed() to avoid race conditions. 

1776 # TODO: Make this method async in the next major release to allow 

1777 # immediate disconnection of free connections. 

1778 existing_node = old[name] 

1779 existing_node.update_active_connections_for_reconnect() 

1780 for conn in existing_node._free: 

1781 conn.mark_for_reconnect() 

1782 continue 

1783 # New node is detected and should be added to the pool 

1784 old[name] = node 

1785 

1786 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None: 

1787 """ 

1788 Move a failing node to the end of startup_nodes and nodes_cache so it's 

1789 tried last during reinitialization and when selecting the default node. 

1790 If the node is not in the respective list, nothing is done. 

1791 """ 

1792 # Move in startup_nodes 

1793 if node_name in self.startup_nodes and len(self.startup_nodes) > 1: 

1794 node = self.startup_nodes.pop(node_name) 

1795 self.startup_nodes[node_name] = node # Re-insert at end 

1796 

1797 # Move in nodes_cache - this affects get_nodes_by_server_type ordering 

1798 # which is used to select the default_node during initialize() 

1799 if node_name in self.nodes_cache and len(self.nodes_cache) > 1: 

1800 node = self.nodes_cache.pop(node_name) 

1801 self.nodes_cache[node_name] = node # Re-insert at end 

1802 

1803 async def move_slot(self, e: AskError | MovedError): 

1804 node_changed = False 

1805 redirected_node = self.get_node(host=e.host, port=e.port) 

1806 if redirected_node: 

1807 # The node already exists 

1808 if redirected_node.server_type != PRIMARY: 

1809 # Update the node's server type 

1810 redirected_node.server_type = PRIMARY 

1811 else: 

1812 # This is a new node, we will add it to the nodes cache 

1813 redirected_node = ClusterNode( 

1814 e.host, e.port, PRIMARY, **self.connection_kwargs 

1815 ) 

1816 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) 

1817 slot_nodes = self.slots_cache[e.slot_id] 

1818 if redirected_node not in slot_nodes: 

1819 # The new slot owner is a new server, or a server from a different 

1820 # shard. We need to remove all current nodes from the slot's list 

1821 # (including replicas) and add just the new node. 

1822 self.slots_cache[e.slot_id] = [redirected_node] 

1823 node_changed = True 

1824 elif redirected_node is not slot_nodes[0]: 

1825 # The MOVED error resulted from a failover, and the new slot owner 

1826 # had previously been a replica. 

1827 old_primary = slot_nodes[0] 

1828 # Update the old primary to be a replica and add it to the end of 

1829 # the slot's node list 

1830 old_primary.server_type = REPLICA 

1831 slot_nodes.append(old_primary) 

1832 # Remove the old replica, which is now a primary, from the slot's 

1833 # node list 

1834 slot_nodes.remove(redirected_node) 

1835 # Override the old primary with the new one 

1836 slot_nodes[0] = redirected_node 

1837 if self.default_node == old_primary: 

1838 # Update the default node with the new primary 

1839 self.default_node = redirected_node 

1840 node_changed = True 

1841 # else: circular MOVED to current primary -> no-op 

1842 # Dispatch so listeners can run shard-pubsub reconciliation; skipped on 

1843 # the no-op branch to avoid needless walks under MOVED storms. A 

1844 # listener must not break slots-cache refresh; log and continue so a 

1845 # single buggy listener cannot starve the rest. 

1846 if node_changed: 

1847 try: 

1848 await self._event_dispatcher.dispatch_async( 

1849 AsyncAfterSlotsCacheRefreshEvent() 

1850 ) 

1851 except Exception as exc: 

1852 # Don't shadow the method parameter ``e``: ``except as`` binds 

1853 # the listener exception in the function scope and ``del``s 

1854 # the name on block exit (PEP 3134), which would also wipe 

1855 # out the original AskError/MovedError parameter. 

1856 logger.exception( 

1857 "listener raised during slots-cache refresh: %s: %s", 

1858 type(exc).__name__, 

1859 exc, 

1860 ) 

1861 

1862 def get_node_from_slot( 

1863 self, 

1864 slot: int, 

1865 read_from_replicas: bool = False, 

1866 load_balancing_strategy=None, 

1867 ) -> "ClusterNode": 

1868 if read_from_replicas is True and load_balancing_strategy is None: 

1869 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN 

1870 

1871 try: 

1872 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: 

1873 # get the server index using the strategy defined in load_balancing_strategy 

1874 primary_name = self.slots_cache[slot][0].name 

1875 node_idx = self.read_load_balancer.get_server_index( 

1876 primary_name, len(self.slots_cache[slot]), load_balancing_strategy 

1877 ) 

1878 return self.slots_cache[slot][node_idx] 

1879 return self.slots_cache[slot][0] 

1880 except (IndexError, TypeError): 

1881 raise SlotNotCoveredError( 

1882 f'Slot "{slot}" not covered by the cluster. ' 

1883 f'"require_full_coverage={self.require_full_coverage}"' 

1884 ) 

1885 
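# Illustrative sketch (the key name is a placeholder): resolving which node
# serves a given key with the key_slot() helper used elsewhere in this module,
# where ``nm`` is a NodesManager with a populated slots cache.
#
#     slot = key_slot(b"user:1000")
#     primary = nm.get_node_from_slot(slot)
#     maybe_replica = nm.get_node_from_slot(
#         slot, load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN
#     )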

1886 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: 

1887 return [ 

1888 node 

1889 for node in self.nodes_cache.values() 

1890 if node.server_type == server_type 

1891 ] 

1892 

1893 async def initialize(self) -> None: 

1894 self.read_load_balancer.reset() 

1895 tmp_nodes_cache: Dict[str, "ClusterNode"] = {} 

1896 tmp_slots: Dict[int, List["ClusterNode"]] = {} 

1897 disagreements = [] 

1898 startup_nodes_reachable = False 

1899 fully_covered = False 

1900 exception = None 

1901 epoch = self._epoch 

1902 

1903 async with self._initialize_lock: 

1904 if self._epoch != epoch: 

1905 # another initialize call has already reinitialized the 

1906 # nodes since we started waiting for the lock; 

1907 # we don't need to do it again. 

1908 return 

1909 

1910 # Convert to tuple to prevent RuntimeError if self.startup_nodes 

1911 # is modified during iteration 

1912 for startup_node in tuple(self.startup_nodes.values()): 

1913 try: 

1914 # Make sure cluster mode is enabled on this node 

1915 try: 

1916 self._event_dispatcher.dispatch( 

1917 AfterAsyncClusterInstantiationEvent( 

1918 self.nodes_cache, 

1919 self.connection_kwargs.get("credential_provider", None), 

1920 ) 

1921 ) 

1922 cluster_slots = await startup_node.execute_command( 

1923 "CLUSTER SLOTS" 

1924 ) 

1925 except ResponseError: 

1926 raise RedisClusterException( 

1927 "Cluster mode is not enabled on this node" 

1928 ) 

1929 startup_nodes_reachable = True 

1930 except Exception as e: 

1931 # Try the next startup node. 

1932 # The exception is saved and raised only if we have no more nodes. 

1933 exception = e 

1934 continue 

1935 

1936 # CLUSTER SLOTS command results in the following output: 

1937 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] 

1938 # where each node contains the following list: [IP, port, node_id] 

1939 # Therefore, cluster_slots[0][2][0] will be the IP address of the 

1940 # primary node of the first slot section. 

1941 # If there's only one server in the cluster, its ``host`` is '' 

1942 # Fix it to the host in startup_nodes 
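# For illustration only, an abbreviated reply of the shape described above
# (addresses and node ids are made up):
#
#     [
#         [0, 5460, ["10.0.0.1", 7000, "node-id-1"], ["10.0.0.4", 7003, "node-id-4"]],
#         [5461, 10922, ["10.0.0.2", 7001, "node-id-2"], ["10.0.0.5", 7004, "node-id-5"]],
#         [10923, 16383, ["10.0.0.3", 7002, "node-id-3"], ["10.0.0.6", 7005, "node-id-6"]],
#     ]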

1943 if ( 

1944 len(cluster_slots) == 1 

1945 and not cluster_slots[0][2][0] 

1946 and len(self.startup_nodes) == 1 

1947 ): 

1948 cluster_slots[0][2][0] = startup_node.host 

1949 

1950 for slot in cluster_slots: 

1951 for i in range(2, len(slot)): 

1952 slot[i] = [str_if_bytes(val) for val in slot[i]] 

1953 primary_node = slot[2] 

1954 host = primary_node[0] 

1955 if host == "": 

1956 host = startup_node.host 

1957 port = int(primary_node[1]) 

1958 host, port = self.remap_host_port(host, port) 

1959 

1960 nodes_for_slot = [] 

1961 

1962 target_node = tmp_nodes_cache.get(get_node_name(host, port)) 

1963 if not target_node: 

1964 target_node = ClusterNode( 

1965 host, port, PRIMARY, **self.connection_kwargs 

1966 ) 

1967 # add this node to the nodes cache 

1968 tmp_nodes_cache[target_node.name] = target_node 

1969 nodes_for_slot.append(target_node) 

1970 

1971 replica_nodes = slot[3:] 

1972 for replica_node in replica_nodes: 

1973 host = replica_node[0] 

1974 port = replica_node[1] 

1975 host, port = self.remap_host_port(host, port) 

1976 

1977 target_replica_node = tmp_nodes_cache.get( 

1978 get_node_name(host, port) 

1979 ) 

1980 if not target_replica_node: 

1981 target_replica_node = ClusterNode( 

1982 host, port, REPLICA, **self.connection_kwargs 

1983 ) 

1984 # add this node to the nodes cache 

1985 tmp_nodes_cache[target_replica_node.name] = target_replica_node 

1986 nodes_for_slot.append(target_replica_node) 

1987 

1988 for i in range(int(slot[0]), int(slot[1]) + 1): 

1989 if i not in tmp_slots: 

1990 tmp_slots[i] = nodes_for_slot 

1991 else: 

1992 # Validate that the cluster does not report two different 

1993 # primaries for the same slot 

1994 tmp_slot = tmp_slots[i][0] 

1995 if tmp_slot.name != target_node.name: 

1996 disagreements.append( 

1997 f"{tmp_slot.name} vs {target_node.name} on slot: {i}" 

1998 ) 

1999 

2000 if len(disagreements) > 5: 

2001 raise RedisClusterException( 

2002 f"startup_nodes could not agree on a valid " 

2003 f"slots cache: {', '.join(disagreements)}" 

2004 ) 

2005 

2006 # Validate if all slots are covered or if we should try next startup node 

2007 fully_covered = True 

2008 for i in range(REDIS_CLUSTER_HASH_SLOTS): 

2009 if i not in tmp_slots: 

2010 fully_covered = False 

2011 break 

2012 if fully_covered: 

2013 break 

2014 

2015 if not startup_nodes_reachable: 

2016 raise RedisClusterException( 

2017 f"Cannot connect to the Redis Cluster. Please provide at least " 

2018 f"one reachable node: {str(exception)}" 

2019 ) from exception 

2020 

2021 # Check if the slots are not fully covered 

2022 if not fully_covered and self.require_full_coverage: 

2023 # Despite the requirement that the slots be covered, there 

2024 # isn't full coverage 

2025 raise RedisClusterException( 

2026 f"Not all slots are covered after querying all startup_nodes. " 

2027 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " 

2028 f"covered..." 

2029 ) 

2030 

2031 # Set the tmp variables to the real variables 

2032 self.slots_cache = tmp_slots 

2033 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) 

2034 

2035 if self._dynamic_startup_nodes: 

2036 # Populate the startup nodes with all discovered nodes 

2037 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) 

2038 

2039 # Set the default node 

2040 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] 

2041 self._epoch += 1 

2042 # Dispatch so listeners (e.g. ClusterPubSub) can reconcile per-node 

2043 # state after slot ownership may have changed. A listener must not 

2044 # break slots-cache refresh; log and continue so a single buggy 

2045 # listener cannot starve the rest. 

2046 try: 

2047 await self._event_dispatcher.dispatch_async( 

2048 AsyncAfterSlotsCacheRefreshEvent() 

2049 ) 

2050 except Exception as e: 

2051 logger.exception( 

2052 "listener raised during slots-cache refresh: %s: %s", 

2053 type(e).__name__, 

2054 e, 

2055 ) 

2056 

2057 async def aclose(self, attr: str = "nodes_cache") -> None: 

2058 self.default_node = None 

2059 await asyncio.gather( 

2060 *( 

2061 asyncio.create_task(node.disconnect()) 

2062 for node in getattr(self, attr).values() 

2063 ) 

2064 ) 

2065 

2066 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: 

2067 """ 

2068 Remap the host and port returned from the cluster to a different 

2069 internal value. Useful if the client is not connecting directly 

2070 to the cluster. 

2071 """ 

2072 if self.address_remap: 

2073 return self.address_remap((host, port)) 

2074 return host, port 

2075 
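# Illustrative sketch (all addresses and ports are placeholders): an
# ``address_remap`` callable for a deployment where the announced cluster
# addresses are only reachable through local port forwards.
#
#     FORWARDS = {
#         ("10.0.0.1", 7000): ("127.0.0.1", 17000),
#         ("10.0.0.2", 7001): ("127.0.0.1", 17001),
#     }
#
#     def remap(address):
#         return FORWARDS.get(address, address)
#
# The callable is passed as the ``address_remap`` argument of NodesManager
# (typically forwarded from the cluster client's constructor).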

2076 

2077class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

2078 """ 

2079 Create a new ClusterPipeline object. 

2080 

2081 Usage:: 

2082 

2083 result = await ( 

2084 rc.pipeline() 

2085 .set("A", 1) 

2086 .get("A") 

2087 .hset("K", "F", "V") 

2088 .hgetall("K") 

2089 .mset_nonatomic({"A": 2, "B": 3}) 

2090 .get("A") 

2091 .get("B") 

2092 .delete("A", "B", "K") 

2093 .execute() 

2094 ) 

2095 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] 

2096 

2097 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which 

2098 are split across multiple nodes, you'll get multiple results for them in the array. 

2099 

2100 Retryable errors: 

2101 - :class:`~.ClusterDownError` 

2102 - :class:`~.ConnectionError` 

2103 - :class:`~.TimeoutError` 

2104 

2105 Redirection errors: 

2106 - :class:`~.TryAgainError` 

2107 - :class:`~.MovedError` 

2108 - :class:`~.AskError` 

2109 

2110 :param client: 

2111 | Existing :class:`~.RedisCluster` client 

2112 """ 

2113 

2114 __slots__ = ( 

2115 "cluster_client", 

2116 "_transaction", 

2117 "_execution_strategy", 

2118 ) 

2119 

2120 # Type discrimination marker for @overload self-type pattern 

2121 _is_async_client: Literal[True] = True 

2122 

2123 def __init__( 

2124 self, client: RedisCluster, transaction: Optional[bool] = None 

2125 ) -> None: 

2126 self.cluster_client = client 

2127 self._transaction = transaction 

2128 self._execution_strategy: ExecutionStrategy = ( 

2129 PipelineStrategy(self) 

2130 if not self._transaction 

2131 else TransactionStrategy(self) 

2132 ) 

2133 

2134 @property 

2135 def nodes_manager(self) -> "NodesManager": 

2136 """Get the nodes manager from the cluster client.""" 

2137 return self.cluster_client.nodes_manager 

2138 

2139 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: 

2140 """Set a custom response callback on the cluster client.""" 

2141 self.cluster_client.set_response_callback(command, callback) 

2142 

2143 async def initialize(self) -> "ClusterPipeline": 

2144 await self._execution_strategy.initialize() 

2145 return self 

2146 

2147 async def __aenter__(self) -> "ClusterPipeline": 

2148 return await self.initialize() 

2149 

2150 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: 

2151 await self.reset() 

2152 

2153 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: 

2154 return self.initialize().__await__() 

2155 

2156 def __bool__(self) -> bool: 

2157 "Pipeline instances should always evaluate to True on Python 3+" 

2158 return True 

2159 

2160 def __len__(self) -> int: 

2161 return len(self._execution_strategy) 

2162 

2163 def execute_command( 

2164 self, *args: Union[KeyT, EncodableT], **kwargs: Any 

2165 ) -> "ClusterPipeline": 

2166 """ 

2167 Append a raw command to the pipeline. 

2168 

2169 :param args: 

2170 | Raw command args 

2171 :param kwargs: 

2172 

2173 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` 

2174 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] 

2175 - Rest of the kwargs are passed to the Redis connection 

2176 """ 

2177 return self._execution_strategy.execute_command(*args, **kwargs) 

2178 
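# A minimal sketch (``rc`` is assumed to be a connected RedisCluster; key,
# host and port values are placeholders) of queueing raw commands and pinning
# one of them to an explicit node via the ``target_nodes`` kwarg:
#
#     pipe = rc.pipeline()
#     pipe.execute_command("SET", "k", "v")
#     node = rc.nodes_manager.get_node(host="127.0.0.1", port=7000)
#     pipe.execute_command("INFO", target_nodes=node)
#     results = await pipe.execute()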

2179 async def execute( 

2180 self, raise_on_error: bool = True, allow_redirections: bool = True 

2181 ) -> List[Any]: 

2182 """ 

2183 Execute the pipeline. 

2184 

2185 It will retry the commands up to the number of retries specified in :attr:`retry` 

2186 and then raise an exception. 

2187 

2188 :param raise_on_error: 

2189 | Raise the first error if there are any errors 

2190 :param allow_redirections: 

2191 | Whether to retry each failed command individually in case of redirection 

2192 errors 

2193 

2194 :raises RedisClusterException: if target_nodes is not provided & the command 

2195 can't be mapped to a slot 

2196 """ 

2197 try: 

2198 return await self._execution_strategy.execute( 

2199 raise_on_error, allow_redirections 

2200 ) 

2201 finally: 

2202 await self.reset() 

2203 

2204 def _split_command_across_slots( 

2205 self, command: str, *keys: KeyT 

2206 ) -> "ClusterPipeline": 

2207 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values(): 

2208 self.execute_command(command, *slot_keys) 

2209 

2210 return self 

2211 
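# For illustration (key names are placeholders): a multi-key command is
# partitioned by hash slot, so the call below may queue one DEL per slot,
# e.g. DEL foo, DEL bar and DEL {tag}a {tag}b, which is why the class
# docstring warns that such commands can produce multiple results.
#
#     pipe._split_command_across_slots("DEL", "foo", "bar", "{tag}a", "{tag}b")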

2212 async def reset(self): 

2213 """ 

2214 Reset back to empty pipeline. 

2215 """ 

2216 await self._execution_strategy.reset() 

2217 

2218 def multi(self): 

2219 """ 

2220 Start a transactional block of the pipeline after WATCH commands 

2221 are issued. End the transactional block with `execute`. 

2222 """ 

2223 self._execution_strategy.multi() 

2224 

2225 async def discard(self): 

2226 """Discard all queued commands; only supported in a transactional context.""" 

2227 await self._execution_strategy.discard() 

2228 

2229 async def watch(self, *names): 

2230 """Watches the values at keys ``names``""" 

2231 await self._execution_strategy.watch(*names) 

2232 

2233 async def unwatch(self): 

2234 """Unwatches all previously specified keys""" 

2235 await self._execution_strategy.unwatch() 

2236 

2237 async def unlink(self, *names): 

2238 await self._execution_strategy.unlink(*names) 

2239 

2240 def mset_nonatomic( 

2241 self, mapping: Mapping[AnyKeyT, EncodableT] 

2242 ) -> "ClusterPipeline": 

2243 return self._execution_strategy.mset_nonatomic(mapping) 

2244 

2245 

2246for command in PIPELINE_BLOCKED_COMMANDS: 

2247 command = command.replace(" ", "_").lower() 

2248 if command == "mset_nonatomic": 

2249 continue 

2250 

2251 setattr(ClusterPipeline, command, block_pipeline_command(command)) 

2252 

2253 

2254class PipelineCommand: 

2255 def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: 

2256 self.args = args 

2257 self.kwargs = kwargs 

2258 self.position = position 

2259 self.result: Union[Any, Exception] = None 

2260 self.command_policies: Optional[CommandPolicies] = None 

2261 

2262 def __repr__(self) -> str: 

2263 return f"[{self.position}] {self.args} ({self.kwargs})" 

2264 

2265 

2266class ExecutionStrategy(ABC): 

2267 @abstractmethod 

2268 async def initialize(self) -> "ClusterPipeline": 

2269 """ 

2270 Initialize the execution strategy. 

2271 

2272 See ClusterPipeline.initialize() 

2273 """ 

2274 pass 

2275 

2276 @abstractmethod 

2277 def execute_command( 

2278 self, *args: Union[KeyT, EncodableT], **kwargs: Any 

2279 ) -> "ClusterPipeline": 

2280 """ 

2281 Append a raw command to the pipeline. 

2282 

2283 See ClusterPipeline.execute_command() 

2284 """ 

2285 pass 

2286 

2287 @abstractmethod 

2288 async def execute( 

2289 self, raise_on_error: bool = True, allow_redirections: bool = True 

2290 ) -> List[Any]: 

2291 """ 

2292 Execute the pipeline. 

2293 

2294 It will retry the commands up to the number of retries specified in :attr:`retry` 

2295 and then raise an exception. 

2296 

2297 See ClusterPipeline.execute() 

2298 """ 

2299 pass 

2300 

2301 @abstractmethod 

2302 def mset_nonatomic( 

2303 self, mapping: Mapping[AnyKeyT, EncodableT] 

2304 ) -> "ClusterPipeline": 

2305 """ 

2306 Executes multiple MSET commands, one per hash slot, built from the provided key/value mapping. 

2307 

2308 See ClusterPipeline.mset_nonatomic() 

2309 """ 

2310 pass 

2311 

2312 @abstractmethod 

2313 async def reset(self): 

2314 """ 

2315 Resets current execution strategy. 

2316 

2317 See: ClusterPipeline.reset() 

2318 """ 

2319 pass 

2320 

2321 @abstractmethod 

2322 def multi(self): 

2323 """ 

2324 Starts transactional context. 

2325 

2326 See: ClusterPipeline.multi() 

2327 """ 

2328 pass 

2329 

2330 @abstractmethod 

2331 async def watch(self, *names): 

2332 """ 

2333 Watch given keys. 

2334 

2335 See: ClusterPipeline.watch() 

2336 """ 

2337 pass 

2338 

2339 @abstractmethod 

2340 async def unwatch(self): 

2341 """ 

2342 Unwatches all previously specified keys 

2343 

2344 See: ClusterPipeline.unwatch() 

2345 """ 

2346 pass 

2347 

2348 @abstractmethod 

2349 async def discard(self): 

2350 pass 

2351 

2352 @abstractmethod 

2353 async def unlink(self, *names): 

2354 """ 

2355 Unlink the keys specified by ``names``. 

2356 

2357 See: ClusterPipeline.unlink() 

2358 """ 

2359 pass 

2360 

2361 @abstractmethod 

2362 def __len__(self) -> int: 

2363 pass 

2364 

2365 

2366class AbstractStrategy(ExecutionStrategy): 

2367 def __init__(self, pipe: ClusterPipeline) -> None: 

2368 self._pipe: ClusterPipeline = pipe 

2369 self._command_queue: List["PipelineCommand"] = [] 

2370 

2371 async def initialize(self) -> "ClusterPipeline": 

2372 if self._pipe.cluster_client._initialize: 

2373 await self._pipe.cluster_client.initialize() 

2374 self._command_queue = [] 

2375 return self._pipe 

2376 

2377 def execute_command( 

2378 self, *args: Union[KeyT, EncodableT], **kwargs: Any 

2379 ) -> "ClusterPipeline": 

2380 self._command_queue.append( 

2381 PipelineCommand(len(self._command_queue), *args, **kwargs) 

2382 ) 

2383 return self._pipe 

2384 

2385 def _annotate_exception(self, exception, number, command): 

2386 """ 

2387 Provides extra context to the exception prior to it being handled 

2388 """ 

2389 cmd = " ".join(map(safe_str, command)) 

2390 msg = ( 

2391 f"Command # {number} ({truncate_text(cmd)}) of pipeline " 

2392 f"caused error: {exception.args[0]}" 

2393 ) 

2394 exception.args = (msg,) + exception.args[1:] 

2395 

2396 @abstractmethod 

2397 def mset_nonatomic( 

2398 self, mapping: Mapping[AnyKeyT, EncodableT] 

2399 ) -> "ClusterPipeline": 

2400 pass 

2401 

2402 @abstractmethod 

2403 async def execute( 

2404 self, raise_on_error: bool = True, allow_redirections: bool = True 

2405 ) -> List[Any]: 

2406 pass 

2407 

2408 @abstractmethod 

2409 async def reset(self): 

2410 pass 

2411 

2412 @abstractmethod 

2413 def multi(self): 

2414 pass 

2415 

2416 @abstractmethod 

2417 async def watch(self, *names): 

2418 pass 

2419 

2420 @abstractmethod 

2421 async def unwatch(self): 

2422 pass 

2423 

2424 @abstractmethod 

2425 async def discard(self): 

2426 pass 

2427 

2428 @abstractmethod 

2429 async def unlink(self, *names): 

2430 pass 

2431 

2432 def __len__(self) -> int: 

2433 return len(self._command_queue) 

2434 

2435 

2436class PipelineStrategy(AbstractStrategy): 

2437 def __init__(self, pipe: ClusterPipeline) -> None: 

2438 super().__init__(pipe) 

2439 

2440 def mset_nonatomic( 

2441 self, mapping: Mapping[AnyKeyT, EncodableT] 

2442 ) -> "ClusterPipeline": 

2443 encoder = self._pipe.cluster_client.encoder 

2444 

2445 slots_pairs = {} 

2446 for pair in mapping.items(): 

2447 slot = key_slot(encoder.encode(pair[0])) 

2448 slots_pairs.setdefault(slot, []).extend(pair) 

2449 

2450 for pairs in slots_pairs.values(): 

2451 self.execute_command("MSET", *pairs) 

2452 

2453 return self._pipe 

2454 
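# For illustration (key names are placeholders, and whether they share a slot
# depends on their CRC16 hash): with the slot-splitting above, a mapping whose
# keys land in different slots is queued as several independent MSET commands,
# hence "nonatomic", and each generated MSET contributes its own True to the
# result list.
#
#     await rc.pipeline().mset_nonatomic({"foo": "1", "bar": "2"}).execute()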

2455 async def execute( 

2456 self, raise_on_error: bool = True, allow_redirections: bool = True 

2457 ) -> List[Any]: 

2458 if not self._command_queue: 

2459 return [] 

2460 

2461 try: 

2462 retry_attempts = self._pipe.cluster_client.retry.get_retries() 

2463 while True: 

2464 try: 

2465 if self._pipe.cluster_client._initialize: 

2466 await self._pipe.cluster_client.initialize() 

2467 return await self._execute( 

2468 self._pipe.cluster_client, 

2469 self._command_queue, 

2470 raise_on_error=raise_on_error, 

2471 allow_redirections=allow_redirections, 

2472 ) 

2473 

2474 except RedisCluster.ERRORS_ALLOW_RETRY as e: 

2475 if retry_attempts > 0: 

2476 # Try again with the new cluster setup. All other errors 

2477 # should be raised. 

2478 retry_attempts -= 1 

2479 await self._pipe.cluster_client.aclose() 

2480 await asyncio.sleep(0.25) 

2481 else: 

2482 # All other errors should be raised. 

2483 raise e 

2484 finally: 

2485 await self.reset() 

2486 

2487 async def _execute( 

2488 self, 

2489 client: "RedisCluster", 

2490 stack: List["PipelineCommand"], 

2491 raise_on_error: bool = True, 

2492 allow_redirections: bool = True, 

2493 ) -> List[Any]: 

2494 todo = [ 

2495 cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception) 

2496 ] 

2497 

2498 nodes = {} 

2499 for cmd in todo: 

2500 passed_targets = cmd.kwargs.pop("target_nodes", None) 

2501 command_policies = await client._policy_resolver.resolve( 

2502 cmd.args[0].lower() 

2503 ) 

2504 

2505 if passed_targets and not client._is_node_flag(passed_targets): 

2506 target_nodes = client._parse_target_nodes(passed_targets) 

2507 

2508 if not command_policies: 

2509 command_policies = CommandPolicies() 

2510 else: 

2511 if not command_policies: 

2512 command_flag = client.command_flags.get(cmd.args[0]) 

2513 if not command_flag: 

2514 # Fallback to default policy 

2515 if not client.get_default_node(): 

2516 slot = None 

2517 else: 

2518 slot = await client._determine_slot(*cmd.args) 

2519 if slot is None: 

2520 command_policies = CommandPolicies() 

2521 else: 

2522 command_policies = CommandPolicies( 

2523 request_policy=RequestPolicy.DEFAULT_KEYED, 

2524 response_policy=ResponsePolicy.DEFAULT_KEYED, 

2525 ) 

2526 else: 

2527 if command_flag in client._command_flags_mapping: 

2528 command_policies = CommandPolicies( 

2529 request_policy=client._command_flags_mapping[ 

2530 command_flag 

2531 ] 

2532 ) 

2533 else: 

2534 command_policies = CommandPolicies() 

2535 

2536 target_nodes = await client._determine_nodes( 

2537 *cmd.args, 

2538 request_policy=command_policies.request_policy, 

2539 node_flag=passed_targets, 

2540 ) 

2541 if not target_nodes: 

2542 raise RedisClusterException( 

2543 f"No targets were found to execute {cmd.args} command on" 

2544 ) 

2545 cmd.command_policies = command_policies 

2546 if len(target_nodes) > 1: 

2547 raise RedisClusterException(f"Too many targets for command {cmd.args}") 

2548 node = target_nodes[0] 

2549 if node.name not in nodes: 

2550 nodes[node.name] = (node, []) 

2551 nodes[node.name][1].append(cmd) 

2552 

2553 # Start timing for observability 

2554 start_time = time.monotonic() 

2555 

2556 errors = await asyncio.gather( 

2557 *( 

2558 asyncio.create_task(node[0].execute_pipeline(node[1])) 

2559 for node in nodes.values() 

2560 ) 

2561 ) 

2562 

2563 # Record operation duration for each node 

2564 for node_name, (node, commands) in nodes.items(): 

2565 # Find the first error in this node's commands, if any 

2566 node_error = None 

2567 for cmd in commands: 

2568 if isinstance(cmd.result, Exception): 

2569 node_error = cmd.result 

2570 break 

2571 

2572 db = node.connection_kwargs.get("db", 0) 

2573 await record_operation_duration( 

2574 command_name="PIPELINE", 

2575 duration_seconds=time.monotonic() - start_time, 

2576 server_address=node.host, 

2577 server_port=node.port, 

2578 db_namespace=str(db) if db is not None else None, 

2579 error=node_error, 

2580 ) 

2581 

2582 if any(errors): 

2583 if allow_redirections: 

2584 # send each errored command individually 

2585 for cmd in todo: 

2586 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): 

2587 try: 

2588 cmd.result = client._policies_callback_mapping[ 

2589 cmd.command_policies.response_policy 

2590 ](await client.execute_command(*cmd.args, **cmd.kwargs)) 

2591 except Exception as e: 

2592 cmd.result = e 

2593 

2594 if raise_on_error: 

2595 for cmd in todo: 

2596 result = cmd.result 

2597 if isinstance(result, Exception): 

2598 command = " ".join(map(safe_str, cmd.args)) 

2599 msg = ( 

2600 f"Command # {cmd.position + 1} " 

2601 f"({truncate_text(command)}) " 

2602 f"of pipeline caused error: {result.args}" 

2603 ) 

2604 result.args = (msg,) + result.args[1:] 

2605 raise result 

2606 

2607 default_cluster_node = client.get_default_node() 

2608 

2609 # Check whether the default node was used. In some cases, 

2610 # 'client.get_default_node()' may return None. The check below 

2611 # prevents a potential AttributeError. 

2612 if default_cluster_node is not None: 

2613 default_node = nodes.get(default_cluster_node.name) 

2614 if default_node is not None: 

2615 # This pipeline execution used the default node, check if we need 

2616 # to replace it. 

2617 # Note: when the error is raised we'll reset the default node in the 

2618 # caller function. 

2619 for cmd in default_node[1]: 

2620 # Check if it has a command that failed with a relevant 

2621 # exception 

2622 if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY: 

2623 client.replace_default_node() 

2624 break 

2625 

2626 return [cmd.result for cmd in stack] 

2627 

2628 async def reset(self): 

2629 """ 

2630 Reset back to empty pipeline. 

2631 """ 

2632 self._command_queue = [] 

2633 

2634 def multi(self): 

2635 raise RedisClusterException( 

2636 "method multi() is not supported outside of transactional context" 

2637 ) 

2638 

2639 async def watch(self, *names): 

2640 raise RedisClusterException( 

2641 "method watch() is not supported outside of transactional context" 

2642 ) 

2643 

2644 async def unwatch(self): 

2645 raise RedisClusterException( 

2646 "method unwatch() is not supported outside of transactional context" 

2647 ) 

2648 

2649 async def discard(self): 

2650 raise RedisClusterException( 

2651 "method discard() is not supported outside of transactional context" 

2652 ) 

2653 

2654 async def unlink(self, *names): 

2655 if len(names) != 1: 

2656 raise RedisClusterException( 

2657 "unlinking multiple keys is not implemented in pipeline command" 

2658 ) 

2659 

2660 return self.execute_command("UNLINK", names[0]) 

2661 

2662 

2663class TransactionStrategy(AbstractStrategy): 

2664 NO_SLOTS_COMMANDS = {"UNWATCH"} 

2665 IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} 

2666 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} 

2667 SLOT_REDIRECT_ERRORS = (AskError, MovedError) 

2668 CONNECTION_ERRORS = ( 

2669 ConnectionError, 

2670 OSError, 

2671 ClusterDownError, 

2672 SlotNotCoveredError, 

2673 ) 

2674 

2675 def __init__(self, pipe: ClusterPipeline) -> None: 

2676 super().__init__(pipe) 

2677 self._explicit_transaction = False 

2678 self._watching = False 

2679 self._pipeline_slots: Set[int] = set() 

2680 self._transaction_node: Optional[ClusterNode] = None 

2681 self._transaction_connection: Optional[Connection] = None 

2682 self._executing = False 

2683 self._retry = copy(self._pipe.cluster_client.retry) 

2684 self._retry.update_supported_errors( 

2685 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS 

2686 ) 

2687 

2688 def _get_client_and_connection_for_transaction( 

2689 self, 

2690 ) -> Tuple[ClusterNode, Connection]: 

2691 """ 

2692 Find a connection for a pipeline transaction. 

2693 

2694 For an atomic transaction, watched keys guarantee that their contents have 

2695 not been altered only as long as the WATCH commands for those keys were sent 

2696 over the same connection. So once we start watching a key, we fetch a 

2697 connection to the node that owns that slot and reuse it. 

2698 """ 

2699 if not self._pipeline_slots: 

2700 raise RedisClusterException( 

2701 "At least one command with a key is needed to identify a node" 

2702 ) 

2703 

2704 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot( 

2705 list(self._pipeline_slots)[0], False 

2706 ) 

2707 self._transaction_node = node 

2708 

2709 if not self._transaction_connection: 

2710 connection: Connection = self._transaction_node.acquire_connection() 

2711 self._transaction_connection = connection 

2712 

2713 return self._transaction_node, self._transaction_connection 

2714 
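# Illustrative sketch (``rc`` is assumed to be a connected RedisCluster; the
# key uses a hash tag so every key maps to one slot, as a cluster transaction
# requires) of the WATCH / MULTI / EXEC flow this strategy implements:
#
#     pipe = ClusterPipeline(rc, transaction=True)
#     await pipe.watch("{user:1000}:balance")
#     balance = int(await rc.get("{user:1000}:balance") or 0)
#     pipe.multi()
#     pipe.set("{user:1000}:balance", balance + 1)
#     results = await pipe.execute()  # WatchError if the watched key changed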

2715 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any": 

2716 # Given the limitations of the ClusterPipeline sync API, we have to run it in a separate thread. 

2717 response = None 

2718 error = None 

2719 

2720 def runner(): 

2721 nonlocal response 

2722 nonlocal error 

2723 try: 

2724 response = asyncio.run(self._execute_command(*args, **kwargs)) 

2725 except Exception as e: 

2726 error = e 

2727 

2728 thread = threading.Thread(target=runner) 

2729 thread.start() 

2730 thread.join() 

2731 

2732 if error: 

2733 raise error 

2734 

2735 return response 

2736 

2737 async def _execute_command( 

2738 self, *args: Union[KeyT, EncodableT], **kwargs: Any 

2739 ) -> Any: 

2740 if self._pipe.cluster_client._initialize: 

2741 await self._pipe.cluster_client.initialize() 

2742 

2743 slot_number: Optional[int] = None 

2744 if args[0] not in self.NO_SLOTS_COMMANDS: 

2745 slot_number = await self._pipe.cluster_client._determine_slot(*args) 

2746 

2747 if ( 

2748 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS 

2749 ) and not self._explicit_transaction: 

2750 if args[0] == "WATCH": 

2751 self._validate_watch() 

2752 

2753 if slot_number is not None: 

2754 if self._pipeline_slots and slot_number not in self._pipeline_slots: 

2755 raise CrossSlotTransactionError( 

2756 "Cannot watch or send commands on different slots" 

2757 ) 

2758 

2759 self._pipeline_slots.add(slot_number) 

2760 elif args[0] not in self.NO_SLOTS_COMMANDS: 

2761 raise RedisClusterException( 

2762 f"Cannot identify slot number for command: {args[0]}; " 

2763 "it cannot be triggered in a transaction" 

2764 ) 

2765 

2766 return self._immediate_execute_command(*args, **kwargs) 

2767 else: 

2768 if slot_number is not None: 

2769 self._pipeline_slots.add(slot_number) 

2770 

2771 return super().execute_command(*args, **kwargs) 

2772 

2773 def _validate_watch(self): 

2774 if self._explicit_transaction: 

2775 raise RedisError("Cannot issue a WATCH after a MULTI") 

2776 

2777 self._watching = True 

2778 

2779 async def _immediate_execute_command(self, *args, **options): 

2780 return await self._retry.call_with_retry( 

2781 lambda: self._get_connection_and_send_command(*args, **options), 

2782 self._reinitialize_on_error, 

2783 with_failure_count=True, 

2784 ) 

2785 

2786 async def _get_connection_and_send_command(self, *args, **options): 

2787 redis_node, connection = self._get_client_and_connection_for_transaction() 

2788 # Only disconnect if not watching - disconnecting would lose WATCH state 

2789 if not self._watching: 

2790 await redis_node.disconnect_if_needed(connection) 

2791 

2792 # Start timing for observability 

2793 start_time = time.monotonic() 

2794 

2795 try: 

2796 response = await self._send_command_parse_response( 

2797 connection, redis_node, args[0], *args, **options 

2798 ) 

2799 

2800 await record_operation_duration( 

2801 command_name=args[0], 

2802 duration_seconds=time.monotonic() - start_time, 

2803 server_address=connection.host, 

2804 server_port=connection.port, 

2805 db_namespace=str(connection.db), 

2806 ) 

2807 

2808 return response 

2809 except Exception as e: 

2810 e.connection = connection 

2811 await record_operation_duration( 

2812 command_name=args[0], 

2813 duration_seconds=time.monotonic() - start_time, 

2814 server_address=connection.host, 

2815 server_port=connection.port, 

2816 db_namespace=str(connection.db), 

2817 error=e, 

2818 ) 

2819 raise 

2820 

2821 async def _send_command_parse_response( 

2822 self, 

2823 connection: Connection, 

2824 redis_node: ClusterNode, 

2825 command_name, 

2826 *args, 

2827 **options, 

2828 ): 

2829 """ 

2830 Send a command and parse the response 

2831 """ 

2832 

2833 await connection.send_command(*args) 

2834 output = await redis_node.parse_response(connection, command_name, **options) 

2835 

2836 if command_name in self.UNWATCH_COMMANDS: 

2837 self._watching = False 

2838 return output 

2839 

2840 async def _reinitialize_on_error(self, error, failure_count): 

2841 if hasattr(error, "connection"): 

2842 await record_error_count( 

2843 server_address=error.connection.host, 

2844 server_port=error.connection.port, 

2845 network_peer_address=error.connection.host, 

2846 network_peer_port=error.connection.port, 

2847 error_type=error, 

2848 retry_attempts=failure_count, 

2849 is_internal=True, 

2850 ) 

2851 

2852 if self._watching: 

2853 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: 

2854 raise WatchError("Slot rebalancing occurred while watching keys") 

2855 

2856 if ( 

2857 type(error) in self.SLOT_REDIRECT_ERRORS 

2858 or type(error) in self.CONNECTION_ERRORS 

2859 ): 

2860 if self._transaction_connection and self._transaction_node: 

2861 # Disconnect and release back to pool 

2862 await self._transaction_connection.disconnect() 

2863 self._transaction_node.release(self._transaction_connection) 

2864 self._transaction_connection = None 

2865 

2866 self._pipe.cluster_client.reinitialize_counter += 1 

2867 if ( 

2868 self._pipe.cluster_client.reinitialize_steps 

2869 and self._pipe.cluster_client.reinitialize_counter 

2870 % self._pipe.cluster_client.reinitialize_steps 

2871 == 0 

2872 ): 

2873 await self._pipe.cluster_client.nodes_manager.initialize() 

2874 self._pipe.cluster_client.reinitialize_counter = 0  # reset the counter incremented above 

2875 else: 

2876 if isinstance(error, AskError): 

2877 await self._pipe.cluster_client.nodes_manager.move_slot(error) 

2878 

2879 self._executing = False 

2880 

2881 async def _raise_first_error(self, responses, stack, start_time): 

2882 """ 

2883 Raise the first exception on the stack 

2884 """ 

2885 for r, cmd in zip(responses, stack): 

2886 if isinstance(r, Exception): 

2887 self._annotate_exception(r, cmd.position + 1, cmd.args) 

2888 

2889 await record_operation_duration( 

2890 command_name="TRANSACTION", 

2891 duration_seconds=time.monotonic() - start_time, 

2892 server_address=self._transaction_connection.host, 

2893 server_port=self._transaction_connection.port, 

2894 db_namespace=str(self._transaction_connection.db), 

2895 error=r, 

2896 ) 

2897 

2898 raise r 

2899 

2900 def mset_nonatomic( 

2901 self, mapping: Mapping[AnyKeyT, EncodableT] 

2902 ) -> "ClusterPipeline": 

2903 raise NotImplementedError("Method is not supported in transactional context.") 

2904 

2905 async def execute( 

2906 self, raise_on_error: bool = True, allow_redirections: bool = True 

2907 ) -> List[Any]: 

2908 stack = self._command_queue 

2909 if not stack and (not self._watching or not self._pipeline_slots): 

2910 return [] 

2911 

2912 return await self._execute_transaction_with_retries(stack, raise_on_error) 

2913 

2914 async def _execute_transaction_with_retries( 

2915 self, stack: List["PipelineCommand"], raise_on_error: bool 

2916 ): 

2917 return await self._retry.call_with_retry( 

2918 lambda: self._execute_transaction(stack, raise_on_error), 

2919 lambda error, failure_count: self._reinitialize_on_error( 

2920 error, failure_count 

2921 ), 

2922 with_failure_count=True, 

2923 ) 

2924 

2925 async def _execute_transaction( 

2926 self, stack: List["PipelineCommand"], raise_on_error: bool 

2927 ): 

2928 if len(self._pipeline_slots) > 1: 

2929 raise CrossSlotTransactionError( 

2930 "All keys involved in a cluster transaction must map to the same slot" 

2931 ) 

2932 

2933 self._executing = True 

2934 

2935 redis_node, connection = self._get_client_and_connection_for_transaction() 

2936 # Only disconnect if not watching - disconnecting would lose WATCH state 

2937 if not self._watching: 

2938 await redis_node.disconnect_if_needed(connection) 

2939 

2940 stack = chain( 

2941 [PipelineCommand(0, "MULTI")], 

2942 stack, 

2943 [PipelineCommand(0, "EXEC")], 

2944 ) 

2945 commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs] 

2946 packed_commands = connection.pack_commands(commands) 

2947 

2948 # Start timing for observability 

2949 start_time = time.monotonic() 

2950 

2951 await connection.send_packed_command(packed_commands) 

2952 errors = [] 

2953 

2954 # parse off the response for MULTI 

2955 # NOTE: we need to handle ResponseErrors here and continue 

2956 # so that we read all the additional command messages from 

2957 # the socket 

2958 try: 

2959 await redis_node.parse_response(connection, "MULTI") 

2960 except ResponseError as e: 

2961 self._annotate_exception(e, 0, "MULTI") 

2962 errors.append(e) 

2963 except self.CONNECTION_ERRORS as cluster_error: 

2964 self._annotate_exception(cluster_error, 0, "MULTI") 

2965 cluster_error.connection = connection 

2966 raise 

2967 

2968 # and all the other commands 

2969 for i, command in enumerate(self._command_queue): 

2970 if EMPTY_RESPONSE in command.kwargs: 

2971 errors.append((i, command.kwargs[EMPTY_RESPONSE])) 

2972 else: 

2973 try: 

2974 _ = await redis_node.parse_response(connection, "_") 

2975 except self.SLOT_REDIRECT_ERRORS as slot_error: 

2976 self._annotate_exception(slot_error, i + 1, command.args) 

2977 errors.append(slot_error) 

2978 except self.CONNECTION_ERRORS as cluster_error: 

2979 self._annotate_exception(cluster_error, i + 1, command.args) 

2980 cluster_error.connection = connection 

2981 raise 

2982 except ResponseError as e: 

2983 self._annotate_exception(e, i + 1, command.args) 

2984 errors.append(e) 

2985 

2986 response = None 

2987 # parse the EXEC. 

2988 try: 

2989 response = await redis_node.parse_response(connection, "EXEC") 

2990 except ExecAbortError: 

2991 if errors: 

2992 raise errors[0] 

2993 raise 

2994 

2995 self._executing = False 

2996 

2997 # EXEC clears any watched keys 

2998 self._watching = False 

2999 

3000 if response is None: 

3001 raise WatchError("Watched variable changed.") 

3002 

3003 # put any parse errors into the response 

3004 for i, e in errors: 

3005 response.insert(i, e) 

3006 

3007 if len(response) != len(self._command_queue): 

3008 raise InvalidPipelineStack( 

3009 "Unexpected response length for cluster pipeline EXEC." 

3010 " Command stack was {} but response had length {}".format( 

3011 [c.args[0] for c in self._command_queue], len(response) 

3012 ) 

3013 ) 

3014 

3015 # find any errors in the response and raise if necessary 

3016 if raise_on_error or len(errors) > 0: 

3017 await self._raise_first_error( 

3018 response, 

3019 self._command_queue, 

3020 start_time, 

3021 ) 

3022 

3023 # We have to run response callbacks manually 

3024 data = [] 

3025 for r, cmd in zip(response, self._command_queue): 

3026 if not isinstance(r, Exception): 

3027 command_name = cmd.args[0] 

3028 if command_name in self._pipe.cluster_client.response_callbacks: 

3029 r = self._pipe.cluster_client.response_callbacks[command_name]( 

3030 r, **cmd.kwargs 

3031 ) 

3032 data.append(r) 

3033 

3034 await record_operation_duration( 

3035 command_name="TRANSACTION", 

3036 duration_seconds=time.monotonic() - start_time, 

3037 server_address=connection.host, 

3038 server_port=connection.port, 

3039 db_namespace=str(connection.db), 

3040 ) 

3041 

3042 return data 

3043 

3044 async def reset(self): 

3045 self._command_queue = [] 

3046 

3047 # make sure to reset the connection state in the event that we were 

3048 # watching something 

3049 if self._transaction_connection: 

3050 try: 

3051 if self._watching: 

3052 # call this manually since our unwatch or 

3053 # immediate_execute_command methods can call reset() 

3054 await self._transaction_connection.send_command("UNWATCH") 

3055 await self._transaction_connection.read_response() 

3056 # we can safely return the connection to the pool here since we're 

3057 # sure we're no longer WATCHing anything 

3058 self._transaction_node.release(self._transaction_connection) 

3059 self._transaction_connection = None 

3060 except self.CONNECTION_ERRORS: 

3061 # disconnect will also remove any previous WATCHes 

3062 if self._transaction_connection and self._transaction_node: 

3063 await self._transaction_connection.disconnect() 

3064 self._transaction_node.release(self._transaction_connection) 

3065 self._transaction_connection = None 

3066 

3067 # clean up the other instance attributes 

3068 self._transaction_node = None 

3069 self._watching = False 

3070 self._explicit_transaction = False 

3071 self._pipeline_slots = set() 

3072 self._executing = False 

3073 

3074 def multi(self): 

3075 if self._explicit_transaction: 

3076 raise RedisError("Cannot issue nested calls to MULTI") 

3077 if self._command_queue: 

3078 raise RedisError( 

3079 "Commands without an initial WATCH have already been issued" 

3080 ) 

3081 self._explicit_transaction = True 

3082 

3083 async def watch(self, *names): 

3084 if self._explicit_transaction: 

3085 raise RedisError("Cannot issue a WATCH after a MULTI") 

3086 

3087 return await self.execute_command("WATCH", *names) 

3088 

3089 async def unwatch(self): 

3090 if self._watching: 

3091 return await self.execute_command("UNWATCH") 

3092 

3093 return True 

3094 

3095 async def discard(self): 

3096 await self.reset() 

3097 

3098 async def unlink(self, *names): 

3099 return self.execute_command("UNLINK", *names) 

3100 

3101 

3102class _ClusterNodePoolAdapter(ConnectionPoolInterface): 

3103 """Thin adapter exposing the :class:`ConnectionPoolInterface` that 

3104 :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own 

3105 connection pool. 

3106 

3107 Connections are acquired from the node via 

3108 :meth:`ClusterNode.acquire_connection` and returned via 

3109 :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already 

3110 disconnects the connection *before* calling :meth:`release`, so the 

3111 connection is returned to the node's free-queue in a disconnected 

3112 state — guaranteeing that a subscribed socket is never silently 

3113 reused for regular commands. 

3114 

3115 Methods that do not apply to this adapter (the underlying node's 

3116 lifecycle is managed by the cluster, not by individual PubSub 

3117 instances) are implemented as no-ops so the adapter remains a valid 

3118 :class:`ConnectionPoolInterface`. 

3119 """ 

3120 

3121 def __init__(self, node: "ClusterNode") -> None: 

3122 self._node = node 

3123 self.connection_kwargs = node.connection_kwargs 

3124 

3125 # -- methods used by PubSub ------------------------------------------------ 

3126 

3127 def get_encoder(self) -> Encoder: 

3128 return self._node.get_encoder() 

3129 

3130 async def get_connection( 

3131 self, command_name: Optional[str] = None, *keys: Any, **options: Any 

3132 ) -> AbstractConnection: 

3133 connection = self._node.acquire_connection() 

3134 try: 

3135 await connection.connect() 

3136 except BaseException: 

3137 # connect() may fail mid-handshake (e.g. after the TCP socket 

3138 # is established but before AUTH/HELLO completes) leaving the 

3139 # connection in a partially-connected state. Disconnect before 

3140 # returning it to the node's free queue so it is not reused. 

3141 await connection.disconnect() 

3142 self._node.release(connection) 

3143 raise 

3144 return connection 

3145 

3146 async def release(self, connection: AbstractConnection) -> None: 

3147 # PubSub.aclose() disconnects the connection before calling 

3148 # release(), so it is safe to put it back in the node's free 

3149 # queue – it will reconnect lazily on next use. 

3150 self._node.release(connection) 

3151 

3152 # -- no-op stubs for the rest of ConnectionPoolInterface ------------------- 

3153 # The node's connections are shared with regular cluster traffic and its 

3154 # lifecycle is managed by RedisCluster / NodesManager, so the adapter must 

3155 # not reset, disconnect, retry-configure or re-auth them on behalf of a 

3156 # single PubSub instance. 

3157 

3158 def get_protocol(self): 

3159 return self.connection_kwargs.get("protocol", None) 

3160 

3161 def reset(self) -> None: 

3162 pass 

3163 

3164 async def disconnect(self, inuse_connections: bool = True) -> None: 

3165 pass 

3166 

3167 async def aclose(self) -> None: 

3168 pass 

3169 

3170 def set_retry(self, retry: "Retry") -> None: 

3171 pass 

3172 

3173 async def re_auth_callback(self, token: TokenInterface) -> None: 

3174 pass 

3175 

3176 def get_connection_count(self) -> List[Tuple[int, dict]]: 

3177 return [] 

3178 

3179 

3180def _unregister_slots_cache_listener( 

3181 dispatcher_ref: "weakref.ref[EventDispatcher]", 

3182 listener: AsyncEventListenerInterface, 

3183 event_type: Type[object], 

3184) -> None: 

3185 # Module-level finalizer callback. Kept free of strong references to the 

3186 # owning ClusterPubSub so attaching it via weakref.finalize does not 

3187 # extend the pubsub's lifetime. 

3188 dispatcher = dispatcher_ref() 

3189 if dispatcher is not None: 

3190 dispatcher.unregister_listeners({event_type: [listener]}) 

3191 

3192 

3193class ClusterPubSubSlotsCacheListener(AsyncEventListenerInterface): 

3194 """ 

3195 Async listener that forwards AsyncAfterSlotsCacheRefreshEvent to a 

3196 ClusterPubSub. 

3197 

3198 Holds a weak reference to the pubsub so it does not keep the instance 

3199 alive. Deterministic cleanup of the dispatcher's strong reference to this 

3200 listener is performed by a ``weakref.finalize`` attached to the owning 

3201 ClusterPubSub in ``ClusterPubSub.__init__``. 

3202 """ 

3203 

3204 def __init__(self, pubsub: "ClusterPubSub") -> None: 

3205 self._pubsub_ref: "weakref.ref[ClusterPubSub]" = weakref.ref(pubsub) 

3206 

3207 async def listen(self, event: object) -> None: 

3208 pubsub = self._pubsub_ref() 

3209 if pubsub is None: 

3210 # Race window between pubsub GC and the finalizer running; safe 

3211 # no-op, finalizer will remove this listener shortly. 

3212 return 

3213 try: 

3214 await pubsub.on_slots_changed() 

3215 except Exception as e: 

3216 # Listeners must not break slots-cache refresh; log and continue so 

3217 # a single buggy pubsub cannot starve the rest. 

3218 logger.exception( 

3219 "pubsub %r raised during slots-cache change: %s: %s", 

3220 pubsub, 

3221 type(e).__name__, 

3222 e, 

3223 ) 

3224 

3225 
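# Minimal sketch (not part of the original source) of the weak-listener plus
# weakref.finalize pattern used by ClusterPubSub below: the dispatcher holds
# only the listener, the listener holds only a weak reference to its owner,
# and a module-level finalizer (capturing no strong reference to the owner)
# unregisters the listener once the owner is collected. `_OwnerEvent`,
# `_WeakOwnerListener`, and `_attach_weak_listener` are hypothetical names.
class _OwnerEvent:
    pass

class _WeakOwnerListener(AsyncEventListenerInterface):
    def __init__(self, owner: object) -> None:
        self._owner_ref = weakref.ref(owner)

    async def listen(self, event: object) -> None:
        owner = self._owner_ref()
        if owner is None:
            return  # owner already collected; the finalizer will unregister us
        # forward the event to the owner here (hypothetical hook)

def _attach_weak_listener(owner: object, dispatcher: EventDispatcher) -> None:
    listener = _WeakOwnerListener(owner)
    dispatcher.register_listeners({_OwnerEvent: [listener]})
    weakref.finalize(
        owner,
        _unregister_slots_cache_listener,
        weakref.ref(dispatcher),
        listener,
        _OwnerEvent,
    )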

3226class ClusterPubSub(PubSub): 

3227 """ 

3228 Async cluster implementation for pub/sub. 

3229 

3230 IMPORTANT: before using ClusterPubSub, read about the known limitations 

3231 with pubsub in Cluster mode and learn how to work around them: 

3232 https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations 

3233 """ 

3234 

3235 def __init__( 

3236 self, 

3237 redis_cluster: "RedisCluster", 

3238 node: Optional["ClusterNode"] = None, 

3239 host: Optional[str] = None, 

3240 port: Optional[int] = None, 

3241 push_handler_func: Optional[Callable] = None, 

3242 event_dispatcher: Optional[EventDispatcher] = None, 

3243 **kwargs: Any, 

3244 ) -> None: 

3245 """ 

3246 When a pubsub instance is created without specifying a node, a single 

3247 node will be transparently chosen for the pubsub connection on the 

3248 first command execution. The node will be determined by: 

3249 1. Hashing the channel name in the request to find its keyslot 

3250 2. Selecting a node that handles the keyslot: If read_from_replicas is 

3251 set to true or load_balancing_strategy is set, a replica can be selected. 

3252 

3253 :param redis_cluster: RedisCluster instance 

3254 :param node: ClusterNode to connect to 

3255 :param host: Host of the node to connect to 

3256 :param port: Port of the node to connect to 

3257 :param push_handler_func: Optional push handler function 

3258 :param event_dispatcher: Optional event dispatcher 

3259 :param kwargs: Additional keyword arguments 

3260 """ 

3261 self.node = None 

3262 self.set_pubsub_node(redis_cluster, node, host, port) 

3263 

3264 # Borrow the node's own connection pool via an adapter rather than 

3265 # creating a second, detached ConnectionPool for pubsub. 

3266 if self.node is not None: 

3267 connection_pool = _ClusterNodePoolAdapter(self.node) 

3268 else: 

3269 connection_pool = None 

3270 

3271 self.cluster = redis_cluster 

3272 self.node_pubsub_mapping: Dict[str, PubSub] = {} 

3273 # Reverse index: shard channel (normalized) -> owning node.name. Used to 

3274 # route sunsubscribe calls and reconcile subscriptions after slot 

3275 # migration / failover. 

3276 self._shard_channel_to_node: Dict[Any, str] = {} 

3277 # Dedicated lock for shard-subscription bookkeeping. Distinct from 

3278 # the _lock inherited from PubSub (which serializes wire I/O on the cluster-level 

3279 # connection used by aclose / send_command / regular subscribe) so 

3280 # that reconciliation cannot starve those unrelated coroutines 

3281 # during long per-channel migrations. 

3282 self._shard_state_lock: asyncio.Lock = asyncio.Lock() 

3283 # Background tasks created by on_slots_changed; kept to prevent GC. 

3284 self._reconcile_tasks: Set[asyncio.Task] = set() 

3285 self._pubsubs_generator = self._pubsubs_generator() 

3286 if event_dispatcher is None: 

3287 self._event_dispatcher = EventDispatcher() 

3288 else: 

3289 self._event_dispatcher = event_dispatcher 

3290 super().__init__( 

3291 connection_pool=connection_pool, 

3292 encoder=redis_cluster.encoder, 

3293 push_handler_func=push_handler_func, 

3294 event_dispatcher=self._event_dispatcher, 

3295 **kwargs, 

3296 ) 

3297 # Subscribe to slots-cache change notifications so shard subscriptions 

3298 # can be reconciled automatically after topology refreshes. 

3299 nm_dispatcher = redis_cluster.nodes_manager._event_dispatcher 

3300 self._slots_cache_listener = ClusterPubSubSlotsCacheListener(self) 

3301 nm_dispatcher.register_listeners( 

3302 {AsyncAfterSlotsCacheRefreshEvent: [self._slots_cache_listener]} 

3303 ) 

3304 # Deterministic GC-time cleanup so short-lived pubsubs do not leak 

3305 # listeners in the dispatcher when no slots-refresh event ever fires. 

3306 weakref.finalize( 

3307 self, 

3308 _unregister_slots_cache_listener, 

3309 weakref.ref(nm_dispatcher), 

3310 self._slots_cache_listener, 

3311 AsyncAfterSlotsCacheRefreshEvent, 

3312 ) 

3313 
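# Usage sketch (illustrative only): `rc` is assumed to be a connected
# redis.asyncio.cluster.RedisCluster, and get_default_node() is assumed to be
# available on it. Without a node, the pub/sub node is resolved lazily from
# the keyslot of the first channel used; passing a node (or host/port) pins
# the regular pub/sub connection up front.
async def _create_cluster_pubsubs(rc: "RedisCluster") -> None:
    lazy_ps = ClusterPubSub(rc)                           # node picked lazily
    pinned_ps = ClusterPubSub(rc, node=rc.get_default_node())
    await lazy_ps.subscribe("news")                       # resolves node from "news"
    await lazy_ps.aclose()
    await pinned_ps.aclose()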

3314 def set_pubsub_node( 

3315 self, 

3316 cluster: "RedisCluster", 

3317 node: Optional["ClusterNode"] = None, 

3318 host: Optional[str] = None, 

3319 port: Optional[int] = None, 

3320 ) -> None: 

3321 """ 

3322 The pubsub node will be set according to the passed node, host and port. 

3323 When no node, host, or port is specified, the node is set 

3324 to None and will be determined by the keyslot of the channel in the 

3325 first command to be executed. 

3326 RedisClusterException will be thrown if the passed node does not exist 

3327 in the cluster. 

3328 If host is passed without port, or vice versa, a DataError will be 

3329 thrown. 

3330 """ 

3331 if node is not None: 

3332 # node is passed by the user 

3333 self._raise_on_invalid_node(cluster, node, node.host, node.port) 

3334 pubsub_node = node 

3335 elif host is not None and port is not None: 

3336 # host and port passed by the user 

3337 node = cluster.get_node(host=host, port=port) 

3338 self._raise_on_invalid_node(cluster, node, host, port) 

3339 pubsub_node = node 

3340 elif host is not None or port is not None: 

3341 # only one of host and port is specified 

3342 raise DataError("Specify both host and port") 

3343 else: 

3344 # nothing specified by the user 

3345 pubsub_node = None 

3346 self.node = pubsub_node 

3347 
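# Illustrative sketch of the accepted input shapes for set_pubsub_node
# (`ps` is assumed to be a ClusterPubSub and `rc` its RedisCluster):
# node only -> pin to that node; host and port -> resolved via rc.get_node;
# nothing -> node stays None and is resolved lazily; host without port
# (or vice versa) -> DataError.
def _set_node_examples(ps: "ClusterPubSub", rc: "RedisCluster") -> None:
    ps.set_pubsub_node(rc)                                 # lazy resolution
    ps.set_pubsub_node(rc, host="127.0.0.1", port=16379)   # explicit address
    try:
        ps.set_pubsub_node(rc, host="127.0.0.1")           # port missing
    except DataError:
        pass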

3348 def get_pubsub_node(self) -> Optional["ClusterNode"]: 

3349 """ 

3350 Get the node that is being used as the pubsub connection. 

3351 

3352 :return: The ClusterNode being used for pubsub, or None if not yet determined 

3353 """ 

3354 return self.node 

3355 

3356 async def _resubscribe_shard_channels(self) -> None: 

3357 # A single node can own multiple slot ranges, so a batched 

3358 # ``SSUBSCRIBE`` covering every tracked channel would be rejected by 

3359 # Redis with a ``CROSSSLOT`` error. Group by hash slot and emit one 

3360 # ``SSUBSCRIBE`` per slot. 

3361 by_slot: defaultdict[int, dict] = defaultdict(dict) 

3362 for k, v in self.shard_channels.items(): 

3363 by_slot[key_slot(self.encoder.encode(k))][k] = v 

3364 for subscriptions in by_slot.values(): 

3365 await self._resubscribe(subscriptions, self.ssubscribe) 

3366 
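# Minimal standalone sketch of the slot-aware grouping performed above: a
# single node can own many slot ranges, so channels are bucketed by hash
# slot and one SSUBSCRIBE is issued per slot instead of one batched call
# (which Redis would reject with a CROSSSLOT error). `encoder` is assumed to
# be a redis Encoder; only imports already at the top of this module are used.
def _group_channels_by_slot(
    channels: Dict[Any, Optional[Callable]], encoder: Encoder
) -> Dict[int, Dict[Any, Optional[Callable]]]:
    by_slot: Dict[int, Dict[Any, Optional[Callable]]] = defaultdict(dict)
    for channel, handler in channels.items():
        by_slot[key_slot(encoder.encode(channel))][channel] = handler
    return dict(by_slot)
# Two channels hashing to different slots end up in two single-channel
# groups, i.e. two SSUBSCRIBE calls to the same node.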

3367 def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: 

3368 """Get or create a PubSub instance for the given node.""" 

3369 try: 

3370 return self.node_pubsub_mapping[node.name] 

3371 except KeyError: 

3372 pubsub = PubSub( 

3373 connection_pool=_ClusterNodePoolAdapter(node), 

3374 encoder=self.cluster.encoder, 

3375 push_handler_func=self.push_handler_func, 

3376 event_dispatcher=self._event_dispatcher, 

3377 ) 

3378 # Replay shard subscriptions on reconnect with slot-aware grouping 

3379 # so that channels spanning multiple slots owned by this node do 

3380 # not trigger a CROSSSLOT error. 

3381 pubsub._resubscribe_shard_channels = MethodType( 

3382 ClusterPubSub._resubscribe_shard_channels, pubsub 

3383 ) 

3384 self.node_pubsub_mapping[node.name] = pubsub 

3385 return pubsub 

3386 

3387 def _find_node_name_for_pubsub(self, pubsub: PubSub) -> Optional[str]: 

3388 for name, candidate in self.node_pubsub_mapping.items(): 

3389 if candidate is pubsub: 

3390 return name 

3391 return None 

3392 

3393 async def _sharded_message_generator( 

3394 self, timeout: float = 0.0 

3395 ) -> Tuple[Optional[PubSub], Optional[Dict[str, Any]]]: 

3396 """Generate messages from shard channels across all nodes.""" 

3397 for _ in range(len(self.node_pubsub_mapping)): 

3398 pubsub = next(self._pubsubs_generator) 

3399 # Don't pass ignore_subscribe_messages here - let get_sharded_message 

3400 # handle the filtering after processing subscription state changes 

3401 message = await pubsub.get_message( 

3402 ignore_subscribe_messages=False, timeout=timeout 

3403 ) 

3404 if message is not None: 

3405 return pubsub, message 

3406 return None, None 

3407 

3408 def _pubsubs_generator(self) -> Generator[PubSub, None, None]: 

3409 """Generator that yields PubSub instances in round-robin fashion.""" 

3410 while True: 

3411 current_nodes = list(self.node_pubsub_mapping.values()) 

3412 if not current_nodes: 

3413 return # Avoid infinite loop when no subscriptions exist 

3414 yield from current_nodes 

3415 
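# Standalone sketch of the round-robin pattern above: each pass snapshots the
# mapping's current values, so per-node pubsubs added or removed while the
# generator is suspended are only seen from the next pass onward, and an
# empty mapping ends iteration instead of spinning forever.
def _round_robin(mapping: Dict[str, Any]) -> Generator[Any, None, None]:
    while True:
        snapshot = list(mapping.values())
        if not snapshot:
            return
        yield from snapshot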

3416 async def get_sharded_message( 

3417 self, 

3418 ignore_subscribe_messages: bool = False, 

3419 timeout: float = 0.0, 

3420 target_node: Optional["ClusterNode"] = None, 

3421 ) -> Optional[Dict[str, Any]]: 

3422 """ 

3423 Get a message from shard channels. 

3424 

3425 :param ignore_subscribe_messages: Whether to ignore subscribe messages 

3426 :param timeout: Timeout for message retrieval 

3427 :param target_node: Specific node to get message from 

3428 :return: Message dictionary or None 

3429 """ 

3430 pubsub: Optional[PubSub] 

3431 if target_node: 

3432 pubsub = self.node_pubsub_mapping.get(target_node.name) 

3433 if pubsub: 

3434 # Don't pass ignore_subscribe_messages here - let get_sharded_message 

3435 # handle the filtering after processing subscription state changes 

3436 message = await pubsub.get_message( 

3437 ignore_subscribe_messages=False, timeout=timeout 

3438 ) 

3439 else: 

3440 message = None 

3441 else: 

3442 pubsub, message = await self._sharded_message_generator(timeout=timeout) 

3443 

3444 if message is None: 

3445 return None 

3446 # Only sunsubscribe mutates cluster-level shard state; bypassing the 

3447 # lock on the data-message hot path keeps smessage delivery from 

3448 # competing with the reconciliation task for _shard_state_lock. 

3449 if str_if_bytes(message["type"]) == "sunsubscribe": 

3450 # Serialize state mutation against reinitialize_shard_subscriptions 

3451 # (background task). The blocking get_message above intentionally 

3452 # runs outside the lock so reconciliation is not stalled by long 

3453 # polls. 

3454 async with self._shard_state_lock: 

3455 if message["channel"] in self.pending_unsubscribe_shard_channels: 

3456 # User-initiated sunsubscribe: drop from cluster-level tracking. 

3457 self.pending_unsubscribe_shard_channels.remove(message["channel"]) 

3458 self.shard_channels.pop(message["channel"], None) 

3459 self._shard_channel_to_node.pop(message["channel"], None) 

3460 # Drop the per-node pubsub that delivered the confirmation once 

3461 # it no longer holds any shard subscriptions, regardless of 

3462 # whether the sunsubscribe was user-initiated or driven by 

3463 # slot-migration reconciliation (_migrate_shard_channel, which 

3464 # intentionally does not add the channel to 

3465 # pending_unsubscribe_shard_channels). This releases the 

3466 # dedicated connection that would otherwise linger. 

3467 # Identifying the receiving pubsub directly (rather than via 

3468 # the cluster's current slot map) is required after slot 

3469 # migration, where the channel's owner is no longer the node 

3470 # that received our original SSUBSCRIBE. 

3471 if pubsub is not None and not pubsub.subscribed: 

3472 name = self._find_node_name_for_pubsub(pubsub) 

3473 if name is not None: 

3474 try: 

3475 await pubsub.aclose() 

3476 except Exception: 

3477 pass 

3478 self.node_pubsub_mapping.pop(name, None) 

3479 

3480 # Only suppress subscribe/unsubscribe messages, not data messages (smessage) 

3481 if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"): 

3482 if self.ignore_subscribe_messages or ignore_subscribe_messages: 

3483 return None 

3484 return message 

3485 
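# Usage sketch (illustrative only; `rc` is assumed to be a connected
# redis.asyncio.cluster.RedisCluster): subscribe to shard channels, then poll
# get_sharded_message, which round-robins the per-node pubsubs and, when
# asked, suppresses ssubscribe/sunsubscribe confirmations but never data.
async def _consume_shard_messages(rc: "RedisCluster") -> None:
    ps = ClusterPubSub(rc)
    await ps.ssubscribe("orders", alerts=lambda m: print("alert:", m["data"]))
    try:
        for _ in range(100):  # bounded polling for this sketch
            message = await ps.get_sharded_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message and str_if_bytes(message["type"]) == "smessage":
                print(message["channel"], message["data"])
    finally:
        await ps.sunsubscribe()
        await ps.aclose()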

3486 async def ssubscribe(self, *args: Any, **kwargs: Any) -> None: 

3487 """ 

3488 Subscribe to shard channels. 

3489 

3490 :param args: Channel names 

3491 :param kwargs: Channel names with handlers 

3492 """ 

3493 if args: 

3494 args = list_or_args(args[0], args[1:]) 

3495 s_channels = dict.fromkeys(args) 

3496 s_channels.update(kwargs) 

3497 

3498 # Serialize against reinitialize_shard_subscriptions (background 

3499 # task) so the reverse index, shard_channels, and node_pubsub_mapping 

3500 # are not mutated concurrently. _migrate_shard_channel below does not 

3501 # re-acquire this lock (asyncio.Lock is non-reentrant). 

3502 async with self._shard_state_lock: 

3503 for s_channel, handler in s_channels.items(): 

3504 node = self.cluster.get_node_from_key(s_channel) 

3505 if not node: 

3506 continue 

3507 # Lazy re-route: if this channel is already tracked against a 

3508 # different node (e.g. after a slot migration), migrate it now 

3509 # so the caller's intent is applied on the current owner. 

3510 normalized_key = next(iter(self._normalize_keys({s_channel: None}))) 

3511 old_name = self._shard_channel_to_node.get(normalized_key) 

3512 if old_name and old_name != node.name: 

3513 # Match PubSub.ssubscribe() dict.update() semantics: the 

3514 # caller's newly supplied handler (including None) always 

3515 # overrides any previously registered handler. 

3516 await self._migrate_shard_channel( 

3517 normalized_key, 

3518 handler, 

3519 old_name, 

3520 node, 

3521 ) 

3522 continue 

3523 pubsub = self._get_node_pubsub(node) 

3524 if handler: 

3525 await pubsub.ssubscribe(**{s_channel: handler}) 

3526 else: 

3527 await pubsub.ssubscribe(s_channel) 

3528 self.shard_channels.update(pubsub.shard_channels) 

3529 self._shard_channel_to_node[normalized_key] = node.name 

3530 self.pending_unsubscribe_shard_channels.difference_update( 

3531 self._normalize_keys({s_channel: None}) 

3532 ) 

3533 

3534 async def sunsubscribe(self, *args: Any) -> None: 

3535 """ 

3536 Unsubscribe from shard channels. 

3537 

3538 :param args: Channel names to unsubscribe from. If empty, unsubscribe from all. 

3539 """ 

3540 if args: 

3541 args = list_or_args(args[0], args[1:]) 

3542 else: 

3543 args = list(self.shard_channels.keys()) 

3544 

3545 # Serialize against reinitialize_shard_subscriptions: the reverse 

3546 # index and node_pubsub_mapping must not change between the lookup 

3547 # and the per-node sunsubscribe call below. 

3548 async with self._shard_state_lock: 

3549 for s_channel in args: 

3550 normalized_key = next(iter(self._normalize_keys({s_channel: None}))) 

3551 # Route via the reverse index so we unsubscribe on the node 

3552 # that actually holds the subscription. After a slot migration 

3553 # the cluster's current owner may no longer be that node. 

3554 name = self._shard_channel_to_node.get(normalized_key) 

3555 if name and name in self.node_pubsub_mapping: 

3556 pubsub = self.node_pubsub_mapping[name] 

3557 else: 

3558 node = self.cluster.get_node_from_key(s_channel) 

3559 if not node or node.name not in self.node_pubsub_mapping: 

3560 continue 

3561 pubsub = self.node_pubsub_mapping[node.name] 

3562 await pubsub.sunsubscribe(s_channel) 

3563 self.pending_unsubscribe_shard_channels.update( 

3564 pubsub.pending_unsubscribe_shard_channels 

3565 ) 

3566 

3567 async def reinitialize_shard_subscriptions(self) -> None: 

3568 """ 

3569 Reconcile per-node shard subscriptions against the cluster's current 

3570 slot ownership map. For each tracked shard channel whose owning node 

3571 has changed (e.g. after CLUSTER SETSLOT / failover), sunsubscribe on 

3572 the old node's pubsub and ssubscribe on the new owner's pubsub, 

3573 preserving any registered handler. 

3574 """ 

3575 uncovered: list = [] 

3576 made_progress = False 

3577 first_migrate_error: Optional[BaseException] = None 

3578 async with self._shard_state_lock: 

3579 for channel, handler in list(self.shard_channels.items()): 

3580 try: 

3581 new_node = self.cluster.get_node_from_key(channel) 

3582 except SlotNotCoveredError: 

3583 # Slot is transiently uncovered (mid-migration / partial 

3584 # topology refresh). Defer this channel so coverable 

3585 # siblings still reconcile this pass; we surface the 

3586 # error below so the caller (and logs) know not every 

3587 # channel was reconciled. Retry happens on the next 

3588 # slots-cache change notification. 

3589 uncovered.append(channel) 

3590 continue 

3591 old_name = self._shard_channel_to_node.get(channel) 

3592 if old_name == new_node.name: 

3593 continue 

3594 try: 

3595 await self._migrate_shard_channel( 

3596 channel, handler, old_name, new_node 

3597 ) 

3598 made_progress = True 

3599 except (ConnectionError, TimeoutError, OSError) as e: 

3600 # Transient connectivity error while subscribing on the 

3601 # new owner (or unsubscribing on the old owner if its 

3602 # handler chose to re-raise). Do not abort reconciliation 

3603 # for sibling channels: _shard_channel_to_node was not 

3604 # advanced for this channel, so the next slots-cache 

3605 # change notification will retry it. 

3606 logger.warning( 

3607 "shard channel %r migration deferred: %s: %s", 

3608 channel, 

3609 type(e).__name__, 

3610 e, 

3611 ) 

3612 if first_migrate_error is None: 

3613 first_migrate_error = e 

3614 continue 

3615 # Garbage-collect per-node pubsubs that no longer hold any 

3616 # subscription so their connections are released. 

3617 for name, pubsub in list(self.node_pubsub_mapping.items()): 

3618 if not pubsub.subscribed: 

3619 try: 

3620 await pubsub.aclose() 

3621 except Exception: 

3622 pass 

3623 self.node_pubsub_mapping.pop(name, None) 

3624 if uncovered: 

3625 # Surface the uncovered channels so the caller (and observer 

3626 # notification path) knows reconciliation was incomplete. All 

3627 # coverable siblings have already been migrated above. 

3628 raise SlotNotCoveredError( 

3629 f"{len(uncovered)} shard channel(s) left unreconciled; " 

3630 f"slot(s) not covered by the cluster: {uncovered!r}" 

3631 ) 

3632 if first_migrate_error is not None and not made_progress: 

3633 # Every migration attempted in this pass failed transiently and 

3634 # nothing else made progress. Re-raise the first caught error 

3635 # (typically the root cause; later failures are often downstream 

3636 # symptoms of the same unreachable node) so the task's done- 

3637 # callback surfaces a single representative failure through the 

3638 # same logger channel used for SlotNotCoveredError. Per-channel 

3639 # WARNINGs above preserve the full forensic detail. 

3640 raise first_migrate_error 

3641 
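# Condensed sketch of the per-channel decision made above (illustrative only;
# `owner_of` stands in for cluster.get_node_from_key and `recorded` for the
# _shard_channel_to_node reverse index):
def _reconcile_action(
    channel: Any,
    recorded: Dict[Any, str],
    owner_of: Callable[[Any], "ClusterNode"],
) -> str:
    try:
        new_owner = owner_of(channel)
    except SlotNotCoveredError:
        return "defer"    # retried on the next slots-cache refresh
    if recorded.get(channel) == new_owner.name:
        return "noop"     # still owned by the recorded node
    return "migrate"      # sunsubscribe on the old owner, ssubscribe on the new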

3642 async def _migrate_shard_channel( 

3643 self, 

3644 channel: Any, 

3645 handler: Optional[Callable], 

3646 old_name: Optional[str], 

3647 new_node: "ClusterNode", 

3648 ) -> None: 

3649 # Detach from the old per-node pubsub, best-effort: the old node may 

3650 # already be unreachable during migration / failover. 

3651 if old_name and old_name in self.node_pubsub_mapping: 

3652 old_pubsub = self.node_pubsub_mapping[old_name] 

3653 try: 

3654 await old_pubsub.sunsubscribe(channel) 

3655 except (ConnectionError, TimeoutError, OSError): 

3656 # redis-py's Connection has already called ``disconnect()`` 

3657 # before raising (see Connection.read_response / 

3658 # send_packed_command with ``disconnect_on_error=True``), 

3659 # so ``old_pubsub``'s dedicated socket is gone. Two cases: 

3660 # 

3661 # 1. The old node is no longer in the cluster topology 

3662 # (e.g. removed by failover / topology refresh): no 

3663 # reconnect target exists, so ``old_pubsub.subscribed`` 

3664 # would stay True forever and the end-of-pass GC block 

3665 # would skip it. Drop it eagerly so the round-robin 

3666 # generator does not keep yielding a dead pubsub that 

3667 # produces periodic errors from ``get_sharded_message``. 

3668 # 2. The old node is still known (transiently slow / 

3669 # unreachable): ``PubSub._execute`` auto-reconnects and 

3670 # ``on_connect`` re-subscribes to remaining channels, 

3671 # so other subscriptions on the same pubsub recover 

3672 # naturally. Leave it alone. 

3673 if self.cluster.get_node(node_name=old_name) is None: 

3674 try: 

3675 await old_pubsub.aclose() 

3676 except Exception: 

3677 pass 

3678 self.node_pubsub_mapping.pop(old_name, None) 

3679 # Attach to the new per-node pubsub, preserving the handler. Decode to 

3680 # a text key only when we must pass it as a kwarg (handler present). 

3681 new_pubsub = self._get_node_pubsub(new_node) 

3682 if handler: 

3683 decoded = ( 

3684 self.encoder.decode(channel, force=True) 

3685 if isinstance(channel, (bytes, bytearray)) 

3686 else channel 

3687 ) 

3688 await new_pubsub.ssubscribe(**{decoded: handler}) 

3689 else: 

3690 await new_pubsub.ssubscribe(channel) 

3691 self.shard_channels.update(new_pubsub.shard_channels) 

3692 normalized_key = next(iter(self._normalize_keys({channel: None}))) 

3693 self._shard_channel_to_node[normalized_key] = new_node.name 

3694 self.pending_unsubscribe_shard_channels.difference_update( 

3695 self._normalize_keys({channel: None}) 

3696 ) 

3697 

3698 async def on_slots_changed(self) -> None: 

3699 # Observer hook invoked by NodesManager after a slots-cache refresh. 

3700 # Schedule reconciliation as a separate task so the caller's code 

3701 # path (typically MovedError handling in _execute_command) is not 

3702 # blocked on the network I/O performed by reinitialize_shard_ 

3703 # subscriptions. No-op when there are no shard subscriptions to 

3704 # reconcile. 

3705 if not self.shard_channels: 

3706 return 

3707 task = asyncio.create_task(self.reinitialize_shard_subscriptions()) 

3708 self._reconcile_tasks.add(task) 

3709 task.add_done_callback(self._reconcile_tasks.discard) 

3710 # Consume the task's exception (if any) so Python does not emit a 

3711 # "Task exception was never retrieved" warning. reinitialize_shard_ 

3712 # subscriptions surfaces SlotNotCoveredError when a slot is still 

3713 # transiently uncovered; route it through the same logger channel 

3714 # as the sync ClusterPubSubSlotsCacheListener for consistent observability. 

3715 task.add_done_callback(self._log_reconcile_task_exception) 

3716 
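# Standalone sketch of the task bookkeeping used by on_slots_changed above:
# keep a strong reference so the task cannot be garbage collected mid-flight,
# drop it when done, and always retrieve the exception so asyncio never warns
# "Task exception was never retrieved". `_background`, `_spawn`, and
# `_consume_exception` are hypothetical illustration names.
_background: Set[asyncio.Task] = set()

def _consume_exception(task: asyncio.Task) -> None:
    if not task.cancelled() and task.exception() is not None:
        logging.getLogger(__name__).error(
            "background task failed", exc_info=task.exception()
        )

def _spawn(coro: Coroutine) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background.add(task)
    task.add_done_callback(_background.discard)
    task.add_done_callback(_consume_exception)
    return task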

3717 @staticmethod 

3718 def _log_reconcile_task_exception(task: "asyncio.Task") -> None: 

3719 if task.cancelled(): 

3720 return 

3721 exc = task.exception() 

3722 if exc is not None: 

3723 logger.error( 

3724 "shard subscription reconciliation failed: %r", exc, exc_info=exc 

3725 ) 

3726 

3727 def get_redis_connection(self) -> Optional["AbstractConnection"]: 

3728 """ 

3729 Get the Redis connection of the pubsub connected node. 

3730 

3731 Returns the pubsub's dedicated connection (acquired from its own 

3732 connection pool), not from the ClusterNode's connection pool. 

3733 This avoids the connection pool resource leak that would occur 

3734 if we called node.acquire_connection() without releasing. 

3735 """ 

3736 # Return the pubsub's own dedicated connection, which is acquired 

3737 # from self.connection_pool when executing pubsub commands. 

3738 # This is safe because it's the connection dedicated to this pubsub 

3739 # instance, not a shared pool connection from the ClusterNode. 

3740 return self.connection 

3741 

3742 async def aclose(self) -> None: 

3743 """ 

3744 Disconnect the pubsub connection. 

3745 """ 

3746 # Cancel and gather in-flight reconciliation tasks BEFORE acquiring 

3747 # _shard_state_lock. The tasks themselves take that lock inside 

3748 # reinitialize_shard_subscriptions; since asyncio.Lock is non- 

3749 # reentrant, gathering while holding it would deadlock. Awaiting 

3750 # each task with suppressed CancelledError also avoids unhandled- 

3751 # exception warnings if the task was created but not yet scheduled. 

3752 if self._reconcile_tasks: 

3753 tasks = list(self._reconcile_tasks) 

3754 for task in tasks: 

3755 task.cancel() 

3756 await asyncio.gather(*tasks, return_exceptions=True) 

3757 # Hold _shard_state_lock across the rest of the teardown so it 

3758 # observes the same mutual-exclusion discipline as ssubscribe / 

3759 # sunsubscribe / get_sharded_message / reinitialize_shard_ 

3760 # subscriptions, which all mutate shard_channels, 

3761 # _shard_channel_to_node, and node_pubsub_mapping under this lock. 

3762 # Without it, super().aclose() rebinds shard_channels and 

3763 # pending_unsubscribe_shard_channels in parallel with a concurrent 

3764 # user-coroutine mutation that resumes during one of the awaits 

3765 # below, silently dropping subscription intent. 

3766 async with self._shard_state_lock: 

3767 self._reconcile_tasks.clear() 

3768 # Close all shard pubsub instances first 

3769 for pubsub in self.node_pubsub_mapping.values(): 

3770 await pubsub.aclose() 

3771 # Drop the now-dead per-node pubsubs from the mapping so the 

3772 # round-robin in _pubsubs_generator / _sharded_message_generator 

3773 # cannot yield them between teardown and re-subscription. 

3774 self.node_pubsub_mapping.clear() 

3775 # _pubsubs_generator captures node_pubsub_mapping.values() into 

3776 # a local list inside ``yield from``; clearing the mapping does 

3777 # not reach references already held by that captured snapshot, 

3778 # so a generator suspended mid-yield-from would still surface 

3779 # the now-aclose()'d per-node pubsubs after re-subscription. 

3780 # Recreate it to drop the captured list. type(self) bypasses 

3781 # the instance-level self-shadow established at __init__ 

3782 # (self._pubsubs_generator = self._pubsubs_generator()). 

3783 self._pubsubs_generator = type(self)._pubsubs_generator( # type: ignore[method-assign] 

3784 self 

3785 ) 

3786 # Let parent handle self.connection disconnect under the lock 

3787 # (includes disconnect, release to pool, and clearing 

3788 # self.connection) 

3789 await super().aclose() 

3790 # Clear the reverse index so a reused instance doesn't route 

3791 # against stale mappings. super().aclose() has already cleared 

3792 # shard_channels. 

3793 self._shard_channel_to_node.clear() 

3794 
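# Minimal sketch of the teardown ordering enforced by aclose above: cancel and
# gather the background tasks BEFORE taking the non-reentrant lock they
# acquire internally (gathering while holding it would deadlock), then mutate
# shared state under the lock.
async def _shutdown(tasks: Set[asyncio.Task], lock: asyncio.Lock) -> None:
    for task in list(tasks):
        task.cancel()
    # return_exceptions=True absorbs the CancelledError from each task
    await asyncio.gather(*tasks, return_exceptions=True)
    async with lock:
        tasks.clear()
        # ... close per-node resources here, as aclose() does above ...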

3795 def _raise_on_invalid_node( 

3796 self, 

3797 redis_cluster: "RedisCluster", 

3798 node: Optional["ClusterNode"], 

3799 host: Optional[str], 

3800 port: Optional[int], 

3801 ) -> None: 

3802 """ 

3803 Raise a RedisClusterException if the node is None or doesn't exist in 

3804 the cluster. 

3805 """ 

3806 if node is None or redis_cluster.get_node(node_name=node.name) is None: 

3807 raise RedisClusterException( 

3808 f"Node {host}:{port} doesn't exist in the cluster" 

3809 ) 

3810 

3811 async def execute_command(self, *args: Any, **kwargs: Any) -> Any: 

3812 """ 

3813 Execute a command on the appropriate cluster node. 

3814 

3815 Code taken from redis-py and tweaked to make it work within a cluster. 

3816 """ 

3817 # NOTE: don't parse the response in this function -- it could pull a 

3818 # legitimate message off the stack if the connection is already 

3819 # subscribed to one or more channels 

3820 

3821 # For shard commands, route to appropriate node 

3822 command = args[0].upper() if args else "" 

3823 if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"): 

3824 if len(args) > 1: 

3825 channel = args[1] 

3826 node = self.cluster.get_node_from_key(channel) 

3827 if node: 

3828 pubsub = self._get_node_pubsub(node) 

3829 return await pubsub.execute_command(*args, **kwargs) 

3830 

3831 # For other commands, use the set node or lazily discover one 

3832 if self.connection is None: 

3833 if self.connection_pool is None: 

3834 if len(args) > 1: 

3835 # Hash the first channel and get one of the nodes holding 

3836 # this slot 

3837 channel = args[1] 

3838 slot = self.cluster.keyslot(channel) 

3839 node = self.cluster.nodes_manager.get_node_from_slot( 

3840 slot, 

3841 self.cluster.read_from_replicas, 

3842 self.cluster.load_balancing_strategy, 

3843 ) 

3844 else: 

3845 # Get a random node 

3846 node = self.cluster.get_random_node() 

3847 self.node = node 

3848 self.connection_pool = _ClusterNodePoolAdapter(node) 

3849 

3850 # Now we have a connection_pool, use parent's execute_command 

3851 return await super().execute_command(*args, **kwargs)