Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/cluster.py: 18%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1533 statements  

1import asyncio 

2import collections 

3import logging 

4import random 

5import socket 

6import threading 

7import time 

8import warnings 

9import weakref 

10from abc import ABC, abstractmethod 

11from collections import defaultdict 

12from copy import copy 

13from itertools import chain 

14from types import MethodType 

15from typing import ( 

16 TYPE_CHECKING, 

17 Any, 

18 Callable, 

19 Coroutine, 

20 Deque, 

21 Dict, 

22 Generator, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Set, 

28 Tuple, 

29 Type, 

30 TypeVar, 

31 Union, 

32) 

33 

34if TYPE_CHECKING: 

35 from redis.asyncio.keyspace_notifications import ( 

36 AsyncClusterKeyspaceNotifications, 

37 ) 

38 

39from redis._defaults import ( 

40 DEFAULT_RETRY_BASE, 

41 DEFAULT_RETRY_CAP, 

42 DEFAULT_RETRY_COUNT, 

43 DEFAULT_SOCKET_CONNECT_TIMEOUT, 

44 DEFAULT_SOCKET_READ_SIZE, 

45 DEFAULT_SOCKET_TIMEOUT, 

46) 

47from redis._parsers import AsyncCommandsParser, Encoder 

48from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy 

49from redis._parsers.helpers import get_response_callbacks 

50from redis.asyncio.client import PubSub, ResponseCallbackT 

51from redis.asyncio.connection import ( 

52 AbstractConnection, 

53 Connection, 

54 ConnectionPoolInterface, 

55 SSLConnection, 

56 parse_url, 

57) 

58from redis.asyncio.lock import Lock 

59from redis.asyncio.observability.recorder import ( 

60 record_error_count, 

61 record_operation_duration, 

62) 

63from redis.asyncio.retry import Retry 

64from redis.auth.token import TokenInterface 

65from redis.backoff import ExponentialWithJitterBackoff, NoBackoff 

66from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis 

67from redis.cluster import ( 

68 PIPELINE_BLOCKED_COMMANDS, 

69 PRIMARY, 

70 REPLICA, 

71 SLOT_ID, 

72 AbstractRedisCluster, 

73 LoadBalancer, 

74 LoadBalancingStrategy, 

75 block_pipeline_command, 

76 get_node_name, 

77 parse_cluster_shards, 

78 parse_cluster_shards_unified, 

79 parse_cluster_shards_with_str_keys, 

80 parse_cluster_slots, 

81) 

82from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands 

83from redis.commands.helpers import list_or_args, parse_pubsub_subscriptions 

84from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver 

85from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot 

86from redis.credentials import CredentialProvider 

87from redis.driver_info import DriverInfo, resolve_driver_info 

88from redis.event import ( 

89 AfterAsyncClusterInstantiationEvent, 

90 AsyncAfterSlotsCacheRefreshEvent, 

91 AsyncEventListenerInterface, 

92 EventDispatcher, 

93) 

94from redis.exceptions import ( 

95 AskError, 

96 BusyLoadingError, 

97 ClusterDownError, 

98 ClusterError, 

99 ConnectionError, 

100 CrossSlotTransactionError, 

101 DataError, 

102 ExecAbortError, 

103 InvalidPipelineStack, 

104 MaxConnectionsError, 

105 MovedError, 

106 RedisClusterException, 

107 RedisError, 

108 ResponseError, 

109 SlotNotCoveredError, 

110 TimeoutError, 

111 TryAgainError, 

112 WatchError, 

113) 

114from redis.typing import ( 

115 AnyKeyT, 

116 ChannelT, 

117 EncodableT, 

118 KeyT, 

119 PubSubHandler, 

120 Subscription, 

121) 

122from redis.utils import ( 

123 SENTINEL, 

124 SSL_AVAILABLE, 

125 deprecated_args, 

126 deprecated_function, 

127 safe_str, 

128 str_if_bytes, 

129 truncate_text, 

130) 

131 

132if SSL_AVAILABLE: 

133 from ssl import TLSVersion, VerifyFlags, VerifyMode 

134else: 

135 TLSVersion = None 

136 VerifyMode = None 

137 VerifyFlags = None 

138 

# Logger named after this module.
logger = logging.getLogger(__name__)

# Constrained type variable for the shapes a ``target_nodes`` argument may take:
# a node-flag string, a single ClusterNode, a list of ClusterNodes, or a dict
# whose values are ClusterNodes.
TargetNodesT = TypeVar(
    "TargetNodesT",
    str,
    "ClusterNode",
    List["ClusterNode"],
    Dict[Any, "ClusterNode"],
)

144 

145 

146class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

147 """ 

148 Create a new RedisCluster client. 

149 

150 Pass one of parameters: 

151 

152 - `host` & `port` 

153 - `startup_nodes` 

154 

155 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. 

156 | Use ``await`` :meth:`close` to disconnect connections & close client. 

157 

158 Many commands support the target_nodes kwarg. It can be one of the 

159 :attr:`NODE_FLAGS`: 

160 

161 - :attr:`PRIMARIES` 

162 - :attr:`REPLICAS` 

163 - :attr:`ALL_NODES` 

164 - :attr:`RANDOM` 

165 - :attr:`DEFAULT_NODE` 

166 

167 Note: This client is not thread/process/fork safe. 

168 

169 :param host: 

170 | Can be used to point to a startup node 

171 :param port: 

172 | Port used if **host** is provided 

173 :param startup_nodes: 

174 | :class:`~.ClusterNode` to be used as a startup node

175 :param require_full_coverage: 

176 | When set to ``False``: the client will not require a full coverage of 

177 the slots. However, if not all slots are covered, and at least one node 

178 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw 

179 a :class:`~.ClusterDownError` for some key-based commands. 

180 | When set to ``True``: all slots must be covered to construct the cluster 

181 client. If not all slots are covered, :class:`~.RedisClusterException` will be 

182 thrown. 

183 | See: 

184 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters 

185 :param read_from_replicas: 

186 | @deprecated - please use load_balancing_strategy instead 

187 | Enable read from replicas in READONLY mode. 

188 When set to true, read commands will be assigned between the primary and 

189 its replicas in a round-robin manner.

190 The data read from replicas is eventually consistent with the data in primary nodes. 

191 :param load_balancing_strategy: 

192 | Enable read from replicas in READONLY mode and defines the load balancing 

193 strategy that will be used for cluster node selection. 

194 The data read from replicas is eventually consistent with the data in primary nodes. 

195 :param dynamic_startup_nodes: 

196 | Set the RedisCluster's startup nodes to all the discovered nodes. 

197 If true (default value), the cluster's discovered nodes will be used to 

198 determine the cluster nodes-slots mapping in the next topology refresh. 

199 It will remove the initial passed startup nodes if their endpoints aren't 

200 listed in the CLUSTER SLOTS output. 

201 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists 

202 specific IP addresses, it is best to set it to false. 

203 :param reinitialize_steps: 

204 | Specifies the number of MOVED errors that need to occur before reinitializing 

205 the whole cluster topology. If a MOVED error occurs and the cluster does not 

206 need to be reinitialized on this current error handling, only the MOVED slot 

207 will be patched with the redirected node. 

208 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. 

209 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 

210 0. 

211 :param cluster_error_retry_attempts: 

212 | @deprecated - Please configure the 'retry' object instead 

213 In case 'retry' object is set - this argument is ignored! 

214 

215 Number of times to retry before raising an error when :class:`~.TimeoutError`, 

216 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` 

217 or :class:`~.ClusterDownError` are encountered 

218 :param retry: 

219 | A retry object that defines the retry strategy and the number of 

220 retries for the cluster client. 

221 In the current implementation of the cluster client (starting from redis-py version 6.0.0)

222 the retry object is not yet fully utilized, instead it is used just to determine 

223 the number of retries for the cluster client. 

224 In future releases the retry object will be used to handle the cluster client retries.

225 :param max_connections: 

226 | Maximum number of connections per node. If there are no free connections & the 

227 maximum number of connections are already created, a 

228 :class:`~.MaxConnectionsError` is raised. 

229 :param socket_keepalive: 

230 | If ``True``, TCP keepalive is enabled for TCP socket connections. 

231 :param socket_keepalive_options: 

232 | Mapping of TCP keepalive socket option constants to values, for 

233 example ``{socket.TCP_KEEPIDLE: 30}``. If left unspecified, redis-py 

234 uses TCP keepalive defaults when ``socket_keepalive`` is enabled: 

235 idle 30 seconds, interval 5 seconds, and 3 probes. 

236 Platform-specific options that are not available are skipped. 

237 Pass ``None`` or ``{}`` to avoid setting additional TCP keepalive 

238 options. 

239 :param address_remap: 

240 | An optional callable which, when provided with an internal network 

241 address of a node, e.g. a `(host, port)` tuple, will return the address 

242 where the node is reachable. This can be used to map the addresses at 

243 which the nodes _think_ they are, to addresses at which a client may 

244 reach them, such as when they sit behind a proxy. 

245 

246 | Rest of the arguments will be passed to the 

247 :class:`~redis.asyncio.connection.Connection` instances when created 

248 

249 :raises RedisClusterException: 

250 if any arguments are invalid or unknown. Eg: 

251 

252 - `db` != 0 or None 

253 - `path` argument for unix socket connection 

254 - none of the `host`/`port` & `startup_nodes` were provided 

255 

256 """ 

257 

258 @classmethod 

259 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": 

260 """ 

261 Return a Redis client object configured from the given URL. 

262 

263 For example:: 

264 

265 redis://[[username]:[password]]@localhost:6379/0 

266 rediss://[[username]:[password]]@localhost:6379/0 

267 

268 Three URL schemes are supported: 

269 

270 - `redis://` creates a TCP socket connection. See more at: 

271 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

272 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

273 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

274 

275 The username, password, hostname, path and all querystring values are passed 

276 through ``urllib.parse.unquote`` in order to replace any percent-encoded values 

277 with their corresponding characters. 

278 

279 All querystring options are cast to their appropriate Python types. Boolean 

280 arguments can be specified with string values "True"/"False" or "Yes"/"No". 

281 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once 

282 parsed, the querystring arguments and keyword arguments are passed to 

283 :class:`~redis.asyncio.connection.Connection` when created. 

284 In the case of conflicting arguments, querystring arguments are used. 

285 """ 

286 kwargs.update(parse_url(url)) 

287 if kwargs.pop("connection_class", None) is SSLConnection: 

288 kwargs["ssl"] = True 

289 return cls(**kwargs) 

290 

291 # Type discrimination marker for @overload self-type pattern 

292 _is_async_client: Literal[True] = True 

293 

294 __slots__ = ( 

295 "_initialize", 

296 "_lock", 

297 "retry", 

298 "command_flags", 

299 "commands_parser", 

300 "connection_kwargs", 

301 "encoder", 

302 "node_flags", 

303 "nodes_manager", 

304 "read_from_replicas", 

305 "reinitialize_counter", 

306 "reinitialize_steps", 

307 "response_callbacks", 

308 "result_callbacks", 

309 ) 

310 

    @deprecated_args(
        args_to_warn=["read_from_replicas"],
        reason="Please configure the 'load_balancing_strategy' instead",
        version="5.3.0",
    )
    @deprecated_args(
        args_to_warn=[
            "cluster_error_retry_attempts",
        ],
        reason="Please configure the 'retry' object instead",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        host: str | None = None,
        port: str | int = 6379,
        # Cluster related kwargs
        startup_nodes: List["ClusterNode"] | None = None,
        require_full_coverage: bool = True,
        read_from_replicas: bool = False,
        load_balancing_strategy: LoadBalancingStrategy | None = None,
        dynamic_startup_nodes: bool = True,
        reinitialize_steps: int = 5,
        cluster_error_retry_attempts: int = DEFAULT_RETRY_COUNT,
        max_connections: int = 100,
        retry: Retry | None = None,
        retry_on_error: List[Type[Exception]] | None = None,
        # Client related kwargs
        db: str | int = 0,
        path: str | None = None,
        credential_provider: CredentialProvider | None = None,
        username: str | None = None,
        password: str | None = None,
        client_name: str | None = None,
        lib_name: str | object | None = SENTINEL,
        lib_version: str | object | None = SENTINEL,
        driver_info: DriverInfo | object | None = SENTINEL,
        # Encoding related kwargs
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        # Connection related kwargs
        health_check_interval: float = 0,
        socket_timeout: float | None = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: float | None = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        socket_keepalive: bool = True,
        socket_keepalive_options: Mapping[int, int | bytes] | object | None = SENTINEL,
        # SSL related kwargs
        ssl: bool = False,
        ssl_ca_certs: str | None = None,
        ssl_ca_data: str | None = None,
        ssl_cert_reqs: "str | VerifyMode" = "required",
        ssl_include_verify_flags: List["VerifyFlags"] | None = None,
        ssl_exclude_verify_flags: List["VerifyFlags"] | None = None,
        ssl_certfile: str | None = None,
        ssl_check_hostname: bool = True,
        ssl_keyfile: str | None = None,
        ssl_min_version: "TLSVersion | None" = None,
        ssl_ciphers: str | None = None,
        protocol: int | None = None,
        legacy_responses: bool = True,
        address_remap: Callable[[Tuple[str, int]], Tuple[str, int]] | None = None,
        event_dispatcher: EventDispatcher | None = None,
        # NOTE(review): this default instance is created once, when the class
        # is defined, and is shared by every client that does not pass its own
        # resolver -- confirm AsyncStaticPolicyResolver holds no per-client state.
        policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
    ) -> None:
        """
        Validate the arguments and build the (not yet initialized) cluster client.

        Parameters are documented on the class. Node discovery and connection
        creation are left to :meth:`initialize` (see the class docstring).

        :raises RedisClusterException: if ``db`` is non-zero, if ``path`` is
            given, or if neither ``host``/``port`` nor ``startup_nodes`` is given.
        """
        # Cluster mode supports neither database selection nor unix sockets.
        if db:
            raise RedisClusterException(
                "Argument 'db' must be 0 or None in cluster mode"
            )

        if path:
            raise RedisClusterException(
                "Unix domain socket is not supported in cluster mode"
            )

        if (not host or not port) and not startup_nodes:
            raise RedisClusterException(
                "RedisCluster requires at least one node to discover the cluster.\n"
                "Please provide one of the following or use RedisCluster.from_url:\n"
                ' - host and port: RedisCluster(host="localhost", port=6379)\n'
                " - startup_nodes: RedisCluster(startup_nodes=["
                'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
            )

        # Presumably folds the deprecated lib_name/lib_version into one
        # driver_info value (resolve_driver_info is defined elsewhere).
        computed_driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        # Keyword arguments handed to every ClusterNode this client creates
        # (see the ClusterNode(...) calls and the NodesManager below).
        kwargs: Dict[str, Any] = {
            "max_connections": max_connections,
            "connection_class": Connection,
            # Client related kwargs
            "credential_provider": credential_provider,
            "username": username,
            "password": password,
            "client_name": client_name,
            "driver_info": computed_driver_info,
            # Encoding related kwargs
            "encoding": encoding,
            "encoding_errors": encoding_errors,
            "decode_responses": decode_responses,
            # Connection related kwargs
            "health_check_interval": health_check_interval,
            "socket_connect_timeout": socket_connect_timeout,
            "socket_keepalive": socket_keepalive,
            "socket_keepalive_options": socket_keepalive_options,
            "socket_read_size": socket_read_size,
            "socket_timeout": socket_timeout,
            "protocol": protocol,
            "legacy_responses": legacy_responses,
        }

        if ssl:
            # SSL related kwargs
            kwargs.update(
                {
                    "connection_class": SSLConnection,
                    "ssl_ca_certs": ssl_ca_certs,
                    "ssl_ca_data": ssl_ca_data,
                    "ssl_cert_reqs": ssl_cert_reqs,
                    "ssl_include_verify_flags": ssl_include_verify_flags,
                    "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                    "ssl_certfile": ssl_certfile,
                    "ssl_check_hostname": ssl_check_hostname,
                    "ssl_keyfile": ssl_keyfile,
                    "ssl_min_version": ssl_min_version,
                    "ssl_ciphers": ssl_ciphers,
                }
            )

        if read_from_replicas or load_balancing_strategy:
            # Call our on_connect function to configure READONLY mode
            kwargs["redis_connect_func"] = self.on_connect

        # An explicit retry object wins; cluster_error_retry_attempts only
        # configures the default one.
        if retry:
            self.retry = retry
        else:
            self.retry = Retry(
                backoff=ExponentialWithJitterBackoff(
                    base=DEFAULT_RETRY_BASE, cap=DEFAULT_RETRY_CAP
                ),
                retries=cluster_error_retry_attempts,
            )
        # NOTE(review): this also runs on a caller-supplied Retry, so it
        # presumably mutates the caller's object -- confirm that is intended.
        if retry_on_error:
            self.retry.update_supported_errors(retry_on_error)

        # Choose the CLUSTER SHARDS parser for the configured response format:
        # unified when legacy_responses is False, string-keyed when no protocol
        # was set explicitly, and the plain parser otherwise.
        kwargs["response_callbacks"] = get_response_callbacks(
            user_protocol=kwargs.get("protocol"),
            legacy_responses=kwargs.get("legacy_responses", True),
        )
        if not kwargs.get("legacy_responses", True):
            kwargs["response_callbacks"]["CLUSTER SHARDS"] = (
                parse_cluster_shards_unified
            )
        elif kwargs.get("protocol") is None:
            kwargs["response_callbacks"]["CLUSTER SHARDS"] = (
                parse_cluster_shards_with_str_keys
            )
        else:
            kwargs["response_callbacks"]["CLUSTER SHARDS"] = parse_cluster_shards
        self.connection_kwargs = kwargs

        # Rebuild the caller's startup nodes so they carry this client's
        # connection kwargs; only host and port are taken from them.
        if startup_nodes:
            passed_nodes = []
            for node in startup_nodes:
                passed_nodes.append(
                    ClusterNode(node.host, node.port, **self.connection_kwargs)
                )
            startup_nodes = passed_nodes
        else:
            startup_nodes = []
        if host and port:
            startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

        self.startup_nodes = startup_nodes
        self.nodes_manager = NodesManager(
            startup_nodes,
            require_full_coverage,
            kwargs,
            dynamic_startup_nodes=dynamic_startup_nodes,
            address_remap=address_remap,
            event_dispatcher=self._event_dispatcher,
        )
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.read_from_replicas = read_from_replicas
        self.load_balancing_strategy = load_balancing_strategy
        self.reinitialize_steps = reinitialize_steps
        self.reinitialize_counter = 0

        # For backward compatibility, mapping from existing policies to new one
        self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
            self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
            self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
            self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
            self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
            self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
            SLOT_ID: RequestPolicy.DEFAULT_KEYED,
        }

        # Callables that resolve a policy to target nodes. _determine_nodes
        # awaits DEFAULT_KEYED with (command, *args), calls DEFAULT_KEYLESS
        # with (command), and calls every other entry with no arguments.
        self._policies_callback_mapping: dict[
            Union[RequestPolicy, ResponsePolicy], Callable
        ] = {
            RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
                self.get_random_primary_or_all_nodes(command_name)
            ],
            RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
            RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
            RequestPolicy.ALL_SHARDS: self.get_primaries,
            RequestPolicy.ALL_NODES: self.get_nodes,
            RequestPolicy.ALL_REPLICAS: self.get_replicas,
            RequestPolicy.SPECIAL: self.get_special_nodes,
            ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
            ResponsePolicy.DEFAULT_KEYED: lambda res: res,
        }

        self._policy_resolver = policy_resolver
        self.commands_parser = AsyncCommandsParser()
        # Nodes used by the most recent FT.AGGREGATE (set in _determine_nodes).
        self._aggregate_nodes = None
        self.node_flags = self.__class__.NODE_FLAGS.copy()
        self.command_flags = self.__class__.COMMAND_FLAGS.copy()
        # Same dict object as kwargs["response_callbacks"] (and thus
        # connection_kwargs), so set_response_callback() updates both.
        self.response_callbacks = kwargs["response_callbacks"]
        self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
        # Parse only the first response value; presumably res maps each
        # queried node to its reply.
        self.result_callbacks["CLUSTER SLOTS"] = (
            lambda cmd, res, **kwargs: parse_cluster_slots(
                list(res.values())[0], **kwargs
            )
        )

        # True while initialize() still has to run (reset to True by aclose()).
        self._initialize = True
        # Created lazily by initialize()/aclose().
        self._lock: Optional[asyncio.Lock] = None

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

556 

557 async def initialize( 

558 self, 

559 additional_startup_nodes_info: Optional[List[Tuple[str, int]]] = None, 

560 last_failed_node_name: Optional[str] = None, 

561 ) -> "RedisCluster": 

562 """Get all nodes from startup nodes & creates connections if not initialized.""" 

563 if self._initialize: 

564 if not self._lock: 

565 self._lock = asyncio.Lock() 

566 async with self._lock: 

567 if self._initialize: 

568 try: 

569 await self.nodes_manager.initialize( 

570 additional_startup_nodes_info=additional_startup_nodes_info, 

571 last_failed_node_name=last_failed_node_name, 

572 ) 

573 await self.commands_parser.initialize( 

574 self.nodes_manager.default_node 

575 ) 

576 self._initialize = False 

577 except BaseException: 

578 await self.nodes_manager.aclose() 

579 await self.nodes_manager.aclose("startup_nodes") 

580 raise 

581 return self 

582 

583 async def aclose(self) -> None: 

584 """Close all connections & client if initialized.""" 

585 if not self._initialize: 

586 if not self._lock: 

587 self._lock = asyncio.Lock() 

588 async with self._lock: 

589 if not self._initialize: 

590 self._initialize = True 

591 await self.nodes_manager.aclose() 

592 await self.nodes_manager.aclose("startup_nodes") 

593 

594 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") 

595 async def close(self) -> None: 

596 """alias for aclose() for backwards compatibility""" 

597 await self.aclose() 

598 

599 async def __aenter__(self) -> "RedisCluster": 

600 """ 

601 Async context manager entry. Increments a usage counter so that the 

602 connection pool is only closed (via aclose()) when no context is using 

603 the client. 

604 """ 

605 await self._increment_usage() 

606 try: 

607 # Initialize the client (i.e. establish connection, etc.) 

608 return await self.initialize() 

609 except Exception: 

610 # If initialization fails, decrement the counter to keep it in sync 

611 await self._decrement_usage() 

612 raise 

613 

614 async def _increment_usage(self) -> int: 

615 """ 

616 Helper coroutine to increment the usage counter while holding the lock. 

617 Returns the new value of the usage counter. 

618 """ 

619 async with self._usage_lock: 

620 self._usage_counter += 1 

621 return self._usage_counter 

622 

623 async def _decrement_usage(self) -> int: 

624 """ 

625 Helper coroutine to decrement the usage counter while holding the lock. 

626 Returns the new value of the usage counter. 

627 """ 

628 async with self._usage_lock: 

629 self._usage_counter -= 1 

630 return self._usage_counter 

631 

632 async def __aexit__(self, exc_type, exc_value, traceback): 

633 """ 

634 Async context manager exit. Decrements a usage counter. If this is the 

635 last exit (counter becomes zero), the client closes its connection pool. 

636 """ 

637 current_usage = await asyncio.shield(self._decrement_usage()) 

638 if current_usage == 0: 

639 # This was the last active context, so disconnect the pool. 

640 await asyncio.shield(self.aclose()) 

641 

642 def __await__(self) -> Generator[Any, None, "RedisCluster"]: 

643 return self.initialize().__await__() 

644 

645 _DEL_MESSAGE = "Unclosed RedisCluster client" 

646 

647 def __del__( 

648 self, 

649 _warn: Any = warnings.warn, 

650 _grl: Any = asyncio.get_running_loop, 

651 ) -> None: 

652 if hasattr(self, "_initialize") and not self._initialize: 

653 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) 

654 try: 

655 context = {"client": self, "message": self._DEL_MESSAGE} 

656 _grl().call_exception_handler(context) 

657 except RuntimeError: 

658 pass 

659 

660 async def on_connect(self, connection: Connection) -> None: 

661 await connection.on_connect() 

662 

663 # Sending READONLY command to server to configure connection as 

664 # readonly. Since each cluster node may change its server type due 

665 # to a failover, we should establish a READONLY connection 

666 # regardless of the server type. If this is a primary connection, 

667 # READONLY would not affect executing write commands. 

668 await connection.send_command("READONLY") 

669 if str_if_bytes(await connection.read_response()) != "OK": 

670 raise ConnectionError("READONLY command failed") 

671 

672 def get_nodes(self) -> List["ClusterNode"]: 

673 """Get all nodes of the cluster.""" 

674 return list(self.nodes_manager.nodes_cache.values()) 

675 

676 def get_primaries(self) -> List["ClusterNode"]: 

677 """Get the primary nodes of the cluster.""" 

678 return self.nodes_manager.get_nodes_by_server_type(PRIMARY) 

679 

680 def get_replicas(self) -> List["ClusterNode"]: 

681 """Get the replica nodes of the cluster.""" 

682 return self.nodes_manager.get_nodes_by_server_type(REPLICA) 

683 

684 def get_random_node(self) -> "ClusterNode": 

685 """Get a random node of the cluster.""" 

686 return random.choice(list(self.nodes_manager.nodes_cache.values())) 

687 

688 def get_default_node(self) -> "ClusterNode": 

689 """Get the default node of the client.""" 

690 return self.nodes_manager.default_node 

691 

692 def set_default_node(self, node: "ClusterNode") -> None: 

693 """ 

694 Set the default node of the client. 

695 

696 :raises DataError: if None is passed or node does not exist in cluster. 

697 """ 

698 if not node or not self.get_node(node_name=node.name): 

699 raise DataError("The requested node does not exist in the cluster.") 

700 

701 self.nodes_manager.default_node = node 

702 

703 def get_node( 

704 self, 

705 host: Optional[str] = None, 

706 port: Optional[int] = None, 

707 node_name: Optional[str] = None, 

708 ) -> Optional["ClusterNode"]: 

709 """Get node by (host, port) or node_name.""" 

710 return self.nodes_manager.get_node(host, port, node_name) 

711 

712 def get_node_from_key( 

713 self, key: str, replica: bool = False 

714 ) -> Optional["ClusterNode"]: 

715 """ 

716 Get the cluster node corresponding to the provided key. 

717 

718 :param key: 

719 :param replica: 

720 | Indicates if a replica should be returned 

721 | 

722 None will returned if no replica holds this key 

723 

724 :raises SlotNotCoveredError: if the key is not covered by any slot. 

725 """ 

726 slot = self.keyslot(key) 

727 slot_cache = self.nodes_manager.slots_cache.get(slot) 

728 if not slot_cache: 

729 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') 

730 

731 if replica: 

732 if len(self.nodes_manager.slots_cache[slot]) < 2: 

733 return None 

734 node_idx = 1 

735 else: 

736 node_idx = 0 

737 

738 return slot_cache[node_idx] 

739 

740 def get_random_primary_or_all_nodes(self, command_name): 

741 """ 

742 Returns random primary or all nodes depends on READONLY mode. 

743 """ 

744 if self.read_from_replicas and command_name in READ_COMMANDS: 

745 return self.get_random_node() 

746 

747 return self.get_random_primary_node() 

748 

749 def get_random_primary_node(self) -> "ClusterNode": 

750 """ 

751 Returns a random primary node 

752 """ 

753 return random.choice(self.get_primaries()) 

754 

755 async def get_nodes_from_slot(self, command: str, *args): 

756 """ 

757 Returns a list of nodes that hold the specified keys' slots. 

758 """ 

759 # get the node that holds the key's slot 

760 return [ 

761 self.nodes_manager.get_node_from_slot( 

762 await self._determine_slot(command, *args), 

763 self.read_from_replicas and command in READ_COMMANDS, 

764 self.load_balancing_strategy if command in READ_COMMANDS else None, 

765 ) 

766 ] 

767 

768 def get_special_nodes(self) -> Optional[list["ClusterNode"]]: 

769 """ 

770 Returns a list of nodes for commands with a special policy. 

771 """ 

772 if not self._aggregate_nodes: 

773 raise RedisClusterException( 

774 "Cannot execute FT.CURSOR commands without FT.AGGREGATE" 

775 ) 

776 

777 return self._aggregate_nodes 

778 

779 def keyslot(self, key: EncodableT) -> int: 

780 """ 

781 Find the keyslot for a given key. 

782 

783 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding 

784 """ 

785 return key_slot(self.encoder.encode(key)) 

786 

787 def get_encoder(self) -> Encoder: 

788 """Get the encoder object of the client.""" 

789 return self.encoder 

790 

791 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: 

792 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" 

793 return self.connection_kwargs 

794 

795 def set_retry(self, retry: Retry) -> None: 

796 self.retry = retry 

797 

798 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: 

799 """Set a custom response callback.""" 

800 self.response_callbacks[command] = callback 

801 

802 async def _determine_nodes( 

803 self, 

804 command: str, 

805 *args: Any, 

806 request_policy: RequestPolicy, 

807 node_flag: Optional[str] = None, 

808 ) -> List["ClusterNode"]: 

809 # Determine which nodes should be executed the command on. 

810 # Returns a list of target nodes. 

811 if not node_flag: 

812 # get the nodes group for this command if it was predefined 

813 node_flag = self.command_flags.get(command) 

814 

815 if node_flag in self._command_flags_mapping: 

816 request_policy = self._command_flags_mapping[node_flag] 

817 

818 policy_callback = self._policies_callback_mapping[request_policy] 

819 

820 if request_policy == RequestPolicy.DEFAULT_KEYED: 

821 nodes = await policy_callback(command, *args) 

822 elif request_policy == RequestPolicy.DEFAULT_KEYLESS: 

823 nodes = policy_callback(command) 

824 else: 

825 nodes = policy_callback() 

826 

827 if command.lower() == "ft.aggregate": 

828 self._aggregate_nodes = nodes 

829 

830 return nodes 

831 

async def _determine_slot(self, command: str, *args: Any) -> int:
    """
    Work out which hash slot ``command`` must be routed to.

    - Commands flagged ``SLOT_ID`` carry the slot number as their first
      argument.
    - ``EVAL``/``EVALSHA`` keys are read straight from the argument list
      (``script numkeys key...``) instead of asking the server.
    - Every other command asks ``self.commands_parser`` for its keys.

    Keyless ``EVAL``/``EVALSHA``/``FCALL``/``FCALL_RO`` calls may run
    anywhere and get a random slot.

    :raises RedisClusterException: if EVAL/EVALSHA has fewer than two
        arguments, if a command has no keys and is not one of the keyless-OK
        commands above, or if the keys hash to more than one slot.
    """
    if self.command_flags.get(command) == SLOT_ID:
        # The slot number is passed explicitly as the first argument.
        return int(args[0])

    upper_command = command.upper()

    # EVAL and EVALSHA are common enough that asking the server to parse
    # their keys is wasteful, and redis<7.0 has a bug that makes that lookup
    # fail anyway:
    # - issue: https://github.com/redis/redis/issues/9493
    # - fix: https://github.com/redis/redis/pull/9733
    if upper_command in ("EVAL", "EVALSHA"):
        # Syntax: EVAL "script body" num_keys [key ...] [arg ...]
        if len(args) < 2:
            raise RedisClusterException(
                f"Invalid args in command: {command, *args}"
            )
        num_keys = int(args[1])
        keys = args[2 : 2 + num_keys]
        if not keys:
            # A script that touches no keys can run on any node.
            return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
    else:
        keys = await self.commands_parser.get_keys(command, *args)
        if not keys:
            # Functions called with zero keys can run on any node.
            if upper_command in ("FCALL", "FCALL_RO"):
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster. "
                "Missing key.\nYou can execute the command by specifying "
                f"target nodes.\nCommand: {args}"
            )

    # All keys (one or many) must agree on a single slot.
    slots = {self.keyslot(key) for key in keys}
    if len(slots) != 1:
        raise RedisClusterException(
            f"{command} - all keys must map to the same key slot"
        )
    return slots.pop()

882 

883 def _is_node_flag(self, target_nodes: Any) -> bool: 

884 return isinstance(target_nodes, str) and target_nodes in self.node_flags 

885 

886 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: 

887 if isinstance(target_nodes, list): 

888 nodes = target_nodes 

889 elif isinstance(target_nodes, ClusterNode): 

890 # Supports passing a single ClusterNode as a variable 

891 nodes = [target_nodes] 

892 elif isinstance(target_nodes, dict): 

893 # Supports dictionaries of the format {node_name: node}. 

894 # It enables to execute commands with multi nodes as follows: 

895 # rc.cluster_save_config(rc.get_primaries()) 

896 nodes = list(target_nodes.values()) 

897 else: 

898 raise TypeError( 

899 "target_nodes type can be one of the following: " 

900 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," 

901 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. " 

902 f"The passed type is {type(target_nodes)}" 

903 ) 

904 return nodes 

905 

async def _record_error_metric(
    self,
    error: Exception,
    connection: Union[Connection, "ClusterNode"],
    is_internal: bool = True,
    retry_attempts: Optional[int] = None,
):
    """
    Emit an error-count metric for ``error`` observed on ``connection``.

    ``connection`` may be a Connection or a ClusterNode; only its ``host``
    and ``port`` are read, and they are reported both as the server and as
    the network peer. An omitted ``retry_attempts`` is reported as 0.
    """
    host = connection.host
    port = connection.port
    attempts = 0 if retry_attempts is None else retry_attempts
    await record_error_count(
        server_address=host,
        server_port=port,
        network_peer_address=host,
        network_peer_port=port,
        error_type=error,
        retry_attempts=attempts,
        is_internal=is_internal,
    )

926 

async def _record_command_metric(
    self,
    command_name: str,
    duration_seconds: float,
    connection: Union[Connection, "ClusterNode"],
    error: Optional[Exception] = None,
):
    """
    Emit an operation-duration metric for ``command_name``.

    The database index comes from ``connection.db`` when the object has that
    attribute (a Connection). Otherwise it comes from
    ``connection.connection_kwargs["db"]`` (a ClusterNode), defaulting to 0.
    A ``None`` db is reported as ``None``; any other value is stringified.
    """
    db = (
        connection.db
        if hasattr(connection, "db")
        else connection.connection_kwargs.get("db", 0)
    )
    namespace = None if db is None else str(db)
    await record_operation_duration(
        command_name=command_name,
        duration_seconds=duration_seconds,
        server_address=connection.host,
        server_port=connection.port,
        db_namespace=namespace,
        error=error,
    )

951 

async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
    """
    Execute a raw command on the appropriate cluster node or target_nodes.

    It will retry the command as specified by the retries property of
    the :attr:`retry` & then raise an exception.

    Explicit ``target_nodes`` (anything other than a node-flag string)
    disable retries entirely. Only exceptions whose exact type (not a
    subclass) is listed in ``ERRORS_ALLOW_RETRY`` are retried. Before each
    attempt the cluster is re-initialized if ``self._initialize`` is set.

    :param args:
        | Raw command args
    :param kwargs:

        - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
          or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
        - Rest of the kwargs are passed to the Redis connection

    :raises RedisClusterException: if target_nodes is not provided & the command
        can't be mapped to a slot
    """
    command = args[0]
    target_nodes = []
    target_nodes_specified = False
    retry_attempts = self.retry.get_retries()

    # A node-flag string is left in passed_targets and resolved later by
    # _determine_nodes; explicit node object(s) are parsed right away.
    passed_targets = kwargs.pop("target_nodes", None)
    if passed_targets and not self._is_node_flag(passed_targets):
        target_nodes = self._parse_target_nodes(passed_targets)
        target_nodes_specified = True
        # Commands sent to explicitly chosen nodes are never retried.
        retry_attempts = 0

    command_policies = await self._policy_resolver.resolve(args[0].lower())

    if not command_policies and not target_nodes_specified:
        # No resolved policy: derive one from the command flag, or from
        # whether the command can be mapped to a slot.
        command_flag = self.command_flags.get(command)
        if not command_flag:
            # Fallback to default policy
            if not self.get_default_node():
                slot = None
            else:
                slot = await self._determine_slot(*args)
            if slot is None:
                command_policies = CommandPolicies()
            else:
                command_policies = CommandPolicies(
                    request_policy=RequestPolicy.DEFAULT_KEYED,
                    response_policy=ResponsePolicy.DEFAULT_KEYED,
                )
        else:
            if command_flag in self._command_flags_mapping:
                command_policies = CommandPolicies(
                    request_policy=self._command_flags_mapping[command_flag]
                )
            else:
                command_policies = CommandPolicies()
    elif not command_policies and target_nodes_specified:
        command_policies = CommandPolicies()

    # Add one for the first execution
    execute_attempts = 1 + retry_attempts
    failure_count = 0

    # Start timing for observability
    start_time = time.monotonic()
    # Set from a retried exception so the next initialize() knows which
    # node failed last.
    last_failed_node_name = None

    for _ in range(execute_attempts):
        if self._initialize:
            await self.initialize(last_failed_node_name=last_failed_node_name)
            last_failed_node_name = None
            # If the only previous target was the default node, pick a new
            # default after re-initialization.
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
        try:
            if not target_nodes_specified:
                # Determine the nodes to execute the command on
                target_nodes = await self._determine_nodes(
                    *args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {args} command on"
                    )

            if len(target_nodes) == 1:
                # Return the processed result
                ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                if command in self.result_callbacks:
                    ret = self.result_callbacks[command](
                        command, {target_nodes[0].name: ret}, **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](ret)
            else:
                # Fan out to every target concurrently; results are keyed by
                # node name.
                keys = [node.name for node in target_nodes]
                values = await asyncio.gather(
                    *(
                        asyncio.create_task(
                            self._execute_command(node, *args, **kwargs)
                        )
                        for node in target_nodes
                    )
                )
                # NOTE(review): unlike the single-node branch, a result
                # callback here returns directly and the response policy
                # callback is not applied - confirm this is intended.
                if command in self.result_callbacks:
                    return self.result_callbacks[command](
                        command, dict(zip(keys, values)), **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](dict(zip(keys, values)))
        except Exception as e:
            # Exact type match: subclasses of retryable errors are not retried.
            if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                # The nodes and slots cache should be reinitialized.
                # Try again with the new cluster setup.
                retry_attempts -= 1
                failure_count += 1
                last_failed_node_name = getattr(e, "last_failed_node_name", None)

                # _execute_command attaches the failing node as e.connection.
                if hasattr(e, "connection"):
                    await self._record_command_metric(
                        command_name=command,
                        duration_seconds=time.monotonic() - start_time,
                        connection=e.connection,
                        error=e,
                    )
                    await self._record_error_metric(
                        error=e,
                        connection=e.connection,
                        retry_attempts=failure_count,
                    )
                continue
            else:
                # raise the exception
                if hasattr(e, "connection"):
                    await self._record_error_metric(
                        error=e,
                        connection=e.connection,
                        retry_attempts=failure_count,
                        is_internal=False,
                    )
                raise e

1097 

async def _execute_command(
    self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
) -> Any:
    """
    Run one command against ``target_node``, following cluster redirections.

    At most ``self.RedisClusterRequestTTL`` attempts are made:

    - ASK: the next attempt sends ``ASKING`` to the redirect node and then
      sends the command there.
    - MOVED: the slot mapping is patched with ``nodes_manager.move_slot``.
      Every ``reinitialize_steps`` redirects the client is closed instead.
      The target node is then recomputed from the slot on the next attempt.
    - TRYAGAIN: retried. Once the remaining TTL drops below half, each retry
      first sleeps 0.05s.

    Every exit path records an operation-duration metric measured from the
    start of this call. Redirect and TRYAGAIN errors also record an
    error-count metric. Errors that propagate get ``e.connection`` set to the
    node, which the caller uses for its own metrics.

    :raises ClusterError: "TTL exhausted." when all attempts are used up.
    """
    asking = moved = False
    redirect_addr = None
    ttl = self.RedisClusterRequestTTL
    command = args[0]
    start_time = time.monotonic()

    while ttl > 0:
        ttl -= 1
        try:
            if asking:
                # Previous attempt got ASK: talk to the redirect node this time.
                target_node = self.get_node(node_name=redirect_addr)
                await target_node.execute_command("ASKING")
                asking = False
            elif moved:
                # MOVED occurred and the slots cache was updated,
                # refresh the target node
                slot = await self._determine_slot(*args)
                target_node = self.nodes_manager.get_node_from_slot(
                    slot,
                    self.read_from_replicas and args[0] in READ_COMMANDS,
                    self.load_balancing_strategy
                    if args[0] in READ_COMMANDS
                    else None,
                )
                moved = False

            response = await target_node.execute_command(*args, **kwargs)
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
            )
            return response
        except BusyLoadingError as e:
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise
        except MaxConnectionsError as e:
            # MaxConnectionsError indicates client-side resource exhaustion
            # (too many connections in the pool), not a node failure.
            # Don't treat this as a node failure - just re-raise the error
            # without reinitializing the cluster.
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise
        except (ConnectionError, TimeoutError) as e:
            # Connection retries are being handled in the node's
            # Retry object.
            # Mark active connections for reconnect and disconnect free ones
            # This handles connection state (like READONLY) that may be stale
            target_node.update_active_connections_for_reconnect()
            await target_node.disconnect_free_connections()

            # Move the failed node to the end of the cached nodes list
            # so it's tried last during reinitialization
            self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)
            # Read back by execute_command and passed to initialize().
            e.last_failed_node_name = target_node.name

            # Signal that reinitialization is needed
            # The retry loop will handle initialize() AND replace_default_node()
            self._initialize = True
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise
        except (ClusterDownError, SlotNotCoveredError) as e:
            # ClusterDownError can occur during a failover and to get
            # self-healed, we will try to reinitialize the cluster layout
            # and retry executing the command

            # SlotNotCoveredError can occur when the cluster is not fully
            # initialized or can be temporary issue.
            # We will try to reinitialize the cluster topology
            # and retry executing the command

            # NOTE(review): relies on aclose() causing re-initialization on
            # the caller's next attempt - confirm against aclose().
            await self.aclose()
            await asyncio.sleep(0.25)
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise
        except MovedError as e:
            # First, we will try to patch the slots/nodes cache with the
            # redirected node output and try again. If MovedError exceeds
            # 'reinitialize_steps' number of times, we will force
            # reinitializing the tables, and then try again.
            # 'reinitialize_steps' counter will increase faster when
            # the same client object is shared between multiple threads. To
            # reduce the frequency you can set this variable in the
            # RedisCluster constructor.
            self.reinitialize_counter += 1
            if (
                self.reinitialize_steps
                and self.reinitialize_counter % self.reinitialize_steps == 0
            ):
                await self.aclose()
                # Reset the counter
                self.reinitialize_counter = 0
            else:
                await self.nodes_manager.move_slot(e)
            # Recompute the target from the slot on the next loop iteration.
            moved = True
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            await self._record_error_metric(
                error=e,
                connection=target_node,
            )
        except AskError as e:
            # Only the next attempt goes to the ASK target; the slot map is
            # not changed.
            redirect_addr = get_node_name(host=e.host, port=e.port)
            asking = True
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            await self._record_error_metric(
                error=e,
                connection=target_node,
            )
        except TryAgainError as e:
            # Back off only once more than half of the TTL has been spent.
            if ttl < self.RedisClusterRequestTTL / 2:
                await asyncio.sleep(0.05)
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            await self._record_error_metric(
                error=e,
                connection=target_node,
            )
        except ResponseError as e:
            # NOTE(review): presumably MovedError/AskError/TryAgainError
            # subclass ResponseError, so this clause must stay below them.
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise
        except Exception as e:
            e.connection = target_node
            await self._record_command_metric(
                command_name=command,
                duration_seconds=time.monotonic() - start_time,
                connection=target_node,
                error=e,
            )
            raise

    # Every attempt ended in a redirect or TRYAGAIN.
    e = ClusterError("TTL exhausted.")
    e.connection = target_node
    await self._record_command_metric(
        command_name=command,
        duration_seconds=time.monotonic() - start_time,
        connection=target_node,
        error=e,
    )
    raise e

1284 

def pipeline(
    self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
) -> "ClusterPipeline":
    """
    Create & return a new :class:`~.ClusterPipeline` object.

    ``transaction`` is forwarded unchanged to :class:`~.ClusterPipeline`
    (:meth:`transaction` relies on this by passing ``True``). ``shard_hint``
    is not supported in cluster mode.

    :param transaction: passed through to the ClusterPipeline constructor.
    :param shard_hint: must be falsy.
    :return: a new ClusterPipeline bound to this client.
    :raises RedisClusterException: if ``shard_hint`` is a truthy value
    """
    if shard_hint:
        raise RedisClusterException("shard_hint is deprecated in cluster mode")

    return ClusterPipeline(self, transaction)

1299 

def pubsub(
    self,
    node: Optional["ClusterNode"] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    **kwargs: Any,
) -> "ClusterPubSub":
    """
    Build a :class:`ClusterPubSub` bound to this cluster client.

    A particular node can be chosen directly with ``node`` or by its
    ``host``/``port`` pair. These, together with any extra keyword
    arguments, are handed unchanged to the ClusterPubSub constructor.

    :param node: ClusterNode to connect to
    :param host: Host of the node to connect to
    :param port: Port of the node to connect to
    :param kwargs: Additional keyword arguments for ClusterPubSub
    :return: ClusterPubSub instance
    """
    # node/host/port are named parameters, so they can never collide with
    # keys already present in kwargs.
    pubsub_kwargs = dict(kwargs, node=node, host=host, port=port)
    return ClusterPubSub(self, **pubsub_kwargs)

1320 

def keyspace_notifications(
    self,
    key_prefix: Union[str, bytes, None] = None,
    ignore_subscribe_messages: bool = True,
) -> "AsyncClusterKeyspaceNotifications":
    """
    Create an
    :class:`~redis.asyncio.keyspace_notifications.AsyncClusterKeyspaceNotifications`
    helper for keyspace and keyevent notifications across all primary nodes
    in the cluster.

    Note: every cluster node must have keyspace notifications enabled through
    the ``notify-keyspace-events`` configuration option.

    Args:
        key_prefix: Optional prefix; notifications are filtered by it and it
            is stripped from the reported keys.
        ignore_subscribe_messages: When True, subscribe/unsubscribe
            confirmations are not returned by get_message/listen.
    """
    # The module-level import of this class only exists under TYPE_CHECKING,
    # so it is imported here at call time.
    from redis.asyncio.keyspace_notifications import (
        AsyncClusterKeyspaceNotifications as _Notifications,
    )

    notifications = _Notifications(
        self,
        key_prefix=key_prefix,
        ignore_subscribe_messages=ignore_subscribe_messages,
    )
    return notifications

1351 

def lock(
    self,
    name: KeyT,
    timeout: Optional[float] = None,
    sleep: float = 0.1,
    blocking: bool = True,
    blocking_timeout: Optional[float] = None,
    lock_class: Optional[Type[Lock]] = None,
    thread_local: bool = True,
    raise_on_release_error: bool = True,
) -> Lock:
    """
    Return a lock object for key ``name`` that mimics ``threading.Lock``.

    ``timeout`` -- maximum lifetime of the lock. When omitted the lock stays
    held until release() is called.

    ``sleep`` -- seconds to pause between attempts while blocking on a lock
    that another client currently holds.

    ``blocking`` -- whether ``acquire`` waits for the lock (the default) or
    fails immediately, returning False without acquiring it. An explicit
    ``blocking`` argument to ``acquire`` overrides this.

    ``blocking_timeout`` -- maximum number of seconds (int or float) to spend
    trying to acquire the lock; ``None`` means keep trying forever.

    ``lock_class`` -- force a specific lock implementation. Since redis-py 3.0
    the only bundled class is the Lua-based ``Lock``, so this is only useful
    for your own custom lock classes.

    ``thread_local`` -- keep the lock token in thread-local storage (the
    default) so each thread only sees its own token. Without it, the
    following timeline lets one thread release another's lock:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                 thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
                 Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
                 key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
                 thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
                 token is *not* stored in thread local storage, then
                 thread-1 would see the token value as "xyz" and would be
                 able to successfully release the thread-2's lock.

    Turn it off when one thread acquires the lock and hands the instance to a
    worker thread that releases it later; otherwise the worker cannot see the
    acquiring thread's token. Such cases are assumed to be uncommon, hence
    the thread-local default.

    ``raise_on_release_error`` -- when True (the default), exiting the
    context manager raises if the lock is no longer owned. When False, a
    warning is logged and the exception is suppressed.
    """
    factory = Lock if lock_class is None else lock_class
    options = {
        "timeout": timeout,
        "sleep": sleep,
        "blocking": blocking,
        "blocking_timeout": blocking_timeout,
        "thread_local": thread_local,
        "raise_on_release_error": raise_on_release_error,
    }
    return factory(self, name, **options)

1432 

async def transaction(
    self, func: Callable[["ClusterPipeline"], Any], *watches, **kwargs
):
    """
    Run ``func`` inside a transactional pipeline, retrying on WatchError.

    The pipeline is created with ``self.pipeline(True, shard_hint)``. Each
    attempt does three things:

    1. WATCHes the keys in ``watches``, if any.
    2. Calls ``func(pipe)``, awaiting the result when it is awaitable.
    3. Executes the pipeline.

    When an attempt raises :class:`WatchError`, the whole attempt is
    repeated. It can optionally wait ``watch_delay`` seconds first, without
    blocking the event loop.

    :param func: callable that receives the pipeline; either a coroutine
        function or a plain function.
    :param watches: keys to WATCH before every attempt.
    :param kwargs:
        - shard_hint: forwarded to :meth:`pipeline`.
        - value_from_callable: if True, return ``func``'s result instead of
          the result of ``pipe.execute()``.
        - watch_delay: seconds to wait before retrying after a WatchError.
    :return: the ``pipe.execute()`` result, or ``func``'s result when
        ``value_from_callable`` is True.
    """
    import inspect

    shard_hint = kwargs.pop("shard_hint", None)
    value_from_callable = kwargs.pop("value_from_callable", False)
    watch_delay = kwargs.pop("watch_delay", None)
    async with self.pipeline(True, shard_hint) as pipe:
        while True:
            try:
                if watches:
                    await pipe.watch(*watches)
                func_value = func(pipe)
                # Accept both coroutine functions and plain callables.
                if inspect.isawaitable(func_value):
                    func_value = await func_value
                exec_value = await pipe.execute()
                return func_value if value_from_callable else exec_value
            except WatchError:
                if watch_delay is not None and watch_delay > 0:
                    # asyncio.sleep yields to the event loop; time.sleep here
                    # would freeze every other task for watch_delay seconds.
                    await asyncio.sleep(watch_delay)
                continue

1456 

1457 

class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are created lazily by :meth:`acquire_connection`, up to
    ``max_connections``. Idle connections wait in a FIFO free queue
    (``_free``) and are reused before new ones are created.
    """

    __slots__ = (
        "_background_tasks",
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 100,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # "localhost" is resolved to an IP address before it is used in the
        # connection kwargs and the node name.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Response callbacks are applied here, in parse_response, so they are
        # removed from the kwargs passed to each connection.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # Every connection ever created (bounded by max_connections) ...
        self._connections: List[Connection] = []
        # ... and the subset currently idle and available for reuse.
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        # Strong references to fire-and-forget tasks created by release().
        self._background_tasks: Set[asyncio.Task] = set()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn once if any connection is still open when the node is
        # garbage-collected. The helpers are bound as defaults so they stay
        # reachable during interpreter shutdown.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop.
                    pass
                break

    async def disconnect(self) -> None:
        """Disconnect every connection; re-raise the first failure after all have been tried."""
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Take an idle connection, or create one while under ``max_connections``.

        :raises MaxConnectionsError: when every allowed connection is in use.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, disconnect it before
        returning it to the free queue.
        """
        if connection.should_reconnect():
            # Disconnect in the background; _disconnect_and_release requeues
            # the connection afterwards.
            task = asyncio.create_task(self._disconnect_and_release(connection))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            return
        self._free.append(connection)

    async def _disconnect_and_release(self, connection: Connection) -> None:
        """Disconnect ``connection`` and requeue it; drop it from the pool if disconnect fails."""
        try:
            await connection.disconnect()
        except Exception as exc:
            logger.debug(
                "disconnecting released cluster connection failed: %r",
                exc,
                exc_info=True,
            )
            try:
                self._connections.remove(connection)
            except ValueError:
                pass
            return

        self._free.append(connection)

    def _release_after_use(self, connection: Connection) -> Coroutine[Any, Any, None]:
        """Return a coroutine that applies any pending lazy disconnect, then always requeues ``connection``."""
        return self._finish_connection(connection)

    async def _finish_connection(self, connection: Connection) -> None:
        # Inner try/finally: even if the lazy disconnect raises, the
        # connection still goes back to the pool instead of leaking.
        try:
            await self.disconnect_if_needed(connection)
        finally:
            self.release(connection)

    def get_encoder(self) -> Encoder:
        """Return an :class:`Encoder` derived from this node's connection kwargs."""
        kwargs = self.connection_kwargs
        encoder_class = kwargs.get("encoder_class", Encoder)
        return encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected after their current operation completes.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and apply this node's response callback.

        ``NEVER_DECODE`` disables decoding for this read. If ``EMPTY_RESPONSE``
        is present, its value is returned in place of a ResponseError.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command on a pooled connection and return its parsed reply.

        The connection is returned to the pool on every exit path, including
        failures while packing or sending the command and task cancellation.
        Otherwise it would remain counted in ``_connections`` without ever
        reaching ``_free``, permanently shrinking the usable pool.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)

            # NOTE(review): assumes the connection resets itself (disconnects)
            # when a send fails part-way, as disconnected pooled connections
            # reconnect on next use - confirm in Connection.send_packed_command.
            await connection.send_packed_command(connection.pack_command(*args))

            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self._finish_connection(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send ``commands`` as one batch and store each reply on its command.

        Each command's ``result`` is set to its parsed reply, or to the
        exception raised while reading it. The connection is returned to the
        pool on every exit path, including send failures and cancellation.

        :return: True if at least one command's result is an exception.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)

            await connection.send_packed_command(
                connection.pack_commands(cmd.args for cmd in commands)
            )

            # Read responses; a failed reply is stored, not raised.
            ret = False
            for cmd in commands:
                try:
                    cmd.result = await self.parse_response(
                        connection, cmd.args[0], **cmd.kwargs
                    )
                except Exception as e:
                    cmd.result = e
                    ret = True
            return ret
        finally:
            await self._finish_connection(connection)

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every idle connection with ``token``.

        Connections are taken out of ``_free`` while being re-authenticated
        and always put back, in their original order, even if an AUTH fails;
        the failure is then re-raised.
        """
        tmp_queue = collections.deque()
        try:
            while self._free:
                conn = self._free.popleft()
                # Track before doing I/O so the connection is restored even if
                # AUTH fails below.
                tmp_queue.append(conn)
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
        finally:
            while tmp_queue:
                conn = tmp_queue.popleft()
                self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

1754 

1755 

1756class NodesManager: 

__slots__ = (
    "_dynamic_startup_nodes",
    "_event_dispatcher",
    "_background_tasks",
    "connection_kwargs",
    "default_node",
    "nodes_cache",
    "_epoch",
    "read_load_balancer",
    "_initialize_lock",
    "require_full_coverage",
    "slots_cache",
    "startup_nodes",
    "address_remap",
)

def __init__(
    self,
    startup_nodes: List["ClusterNode"],
    require_full_coverage: bool,
    connection_kwargs: Dict[str, Any],
    dynamic_startup_nodes: bool = True,
    address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
    event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
    """
    Store the manager's configuration and create empty topology caches.

    :param startup_nodes: seed nodes, stored in ``startup_nodes`` keyed by
        each node's ``name``.
    :param require_full_coverage: stored unchanged.
    :param connection_kwargs: stored unchanged.
    :param dynamic_startup_nodes: stored as ``_dynamic_startup_nodes``.
    :param address_remap: optional callable mapping a ``(host, port)`` tuple
        to another ``(host, port)`` tuple; stored unchanged.
    :param event_dispatcher: dispatcher to use; a fresh EventDispatcher is
        created when omitted.
    """
    # Configuration supplied by the owning client.
    seeds = {}
    for seed in startup_nodes:
        seeds[seed.name] = seed
    self.startup_nodes = seeds
    self.require_full_coverage = require_full_coverage
    self.connection_kwargs = connection_kwargs
    self.address_remap = address_remap

    # Topology state; everything starts empty.
    self.default_node: "ClusterNode" = None
    self.nodes_cache: Dict[str, "ClusterNode"] = {}
    self.slots_cache: Dict[int, List["ClusterNode"]] = {}
    self._epoch: int = 0
    self.read_load_balancer = LoadBalancer()
    self._initialize_lock: asyncio.Lock = asyncio.Lock()

    # Bookkeeping.
    self._background_tasks: Set[asyncio.Task] = set()
    self._dynamic_startup_nodes: bool = dynamic_startup_nodes
    self._event_dispatcher = (
        event_dispatcher if event_dispatcher is not None else EventDispatcher()
    )

1800 

1801 def get_node( 

1802 self, 

1803 host: Optional[str] = None, 

1804 port: Optional[int] = None, 

1805 node_name: Optional[str] = None, 

1806 ) -> Optional["ClusterNode"]: 

1807 if host and port: 

1808 # the user passed host and port 

1809 if host == "localhost": 

1810 host = socket.gethostbyname(host) 

1811 return self.nodes_cache.get(get_node_name(host=host, port=port)) 

1812 elif node_name: 

1813 return self.nodes_cache.get(node_name) 

1814 else: 

1815 raise DataError( 

1816 "get_node requires one of the following: 1. node name 2. host and port" 

1817 ) 

1818 

1819 def set_nodes( 

1820 self, 

1821 old: Dict[str, "ClusterNode"], 

1822 new: Dict[str, "ClusterNode"], 

1823 remove_old: bool = False, 

1824 ) -> None: 

1825 if remove_old: 

1826 for name in list(old.keys()): 

1827 if name not in new: 

1828 # Node is removed from cache before disconnect starts, 

1829 # so it won't be found in lookups during disconnect 

1830 # Mark active connections so in-flight commands can 

1831 # finish, then disconnect them when their current 

1832 # operation completes. Free connections can be 

1833 # disconnected immediately. 

1834 removed_node = old.pop(name) 

1835 removed_node.update_active_connections_for_reconnect() 

1836 task = asyncio.create_task( 

1837 removed_node.disconnect_free_connections() 

1838 ) 

1839 self._background_tasks.add(task) 

1840 task.add_done_callback(self._background_tasks.discard) 

1841 

1842 for name, node in new.items(): 

1843 if name in old: 

1844 # Preserve the existing node but mark connections for reconnect. 

1845 # This method is sync so we can't call disconnect_free_connections() 

1846 # which is async. Instead, we mark free connections for reconnect 

1847 # and they will be lazily disconnected when acquired via 

1848 # disconnect_if_needed() to avoid race conditions. 

1849 # TODO: Make this method async in the next major release to allow 

1850 # immediate disconnection of free connections. 

1851 existing_node = old[name] 

1852 existing_node.server_type = node.server_type 

1853 existing_node.update_active_connections_for_reconnect() 

1854 for conn in existing_node._free: 

1855 conn.mark_for_reconnect() 

1856 continue 

1857 # New node is detected and should be added to the pool 

1858 old[name] = node 

1859 

1860 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None: 

1861 """ 

1862 Move a failing node to the end of startup_nodes and nodes_cache so it's 

1863 tried last during reinitialization and when selecting the default node. 

1864 If the node is not in the respective list, nothing is done. 

1865 """ 

1866 # Move in startup_nodes 

1867 if node_name in self.startup_nodes and len(self.startup_nodes) > 1: 

1868 node = self.startup_nodes.pop(node_name) 

1869 self.startup_nodes[node_name] = node # Re-insert at end 

1870 

1871 # Move in nodes_cache - this affects get_nodes_by_server_type ordering 

1872 # which is used to select the default_node during initialize() 

1873 if node_name in self.nodes_cache and len(self.nodes_cache) > 1: 

1874 node = self.nodes_cache.pop(node_name) 

1875 self.nodes_cache[node_name] = node # Re-insert at end 

1876 

1877 async def move_slot(self, e: AskError | MovedError): 

1878 node_changed = False 

1879 redirected_node = self.get_node(host=e.host, port=e.port) 

1880 if redirected_node: 

1881 # The node already exists 

1882 if redirected_node.server_type != PRIMARY: 

1883 # Update the node's server type 

1884 redirected_node.server_type = PRIMARY 

1885 else: 

1886 # This is a new node, we will add it to the nodes cache 

1887 redirected_node = ClusterNode( 

1888 e.host, e.port, PRIMARY, **self.connection_kwargs 

1889 ) 

1890 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) 

1891 slot_nodes = self.slots_cache[e.slot_id] 

1892 if redirected_node not in slot_nodes: 

1893 # The new slot owner is a new server, or a server from a different 

1894 # shard. We need to remove all current nodes from the slot's list 

1895 # (including replications) and add just the new node. 

1896 self.slots_cache[e.slot_id] = [redirected_node] 

1897 node_changed = True 

1898 elif redirected_node is not slot_nodes[0]: 

1899 # The MOVED error resulted from a failover, and the new slot owner 

1900 # had previously been a replica. 

1901 old_primary = slot_nodes[0] 

1902 # Update the old primary to be a replica and add it to the end of 

1903 # the slot's node list 

1904 old_primary.server_type = REPLICA 

1905 slot_nodes.append(old_primary) 

1906 # Remove the old replica, which is now a primary, from the slot's 

1907 # node list 

1908 slot_nodes.remove(redirected_node) 

1909 # Override the old primary with the new one 

1910 slot_nodes[0] = redirected_node 

1911 if self.default_node == old_primary: 

1912 # Update the default node with the new primary 

1913 self.default_node = redirected_node 

1914 node_changed = True 

1915 # else: circular MOVED to current primary -> no-op 

1916 # Dispatch so listeners can run shard-pubsub reconciliation; skipped on 

1917 # the no-op branch to avoid needless walks under MOVED storms. A 

1918 # listener must not break slots-cache refresh; log and continue so a 

1919 # single buggy listener cannot starve the rest. 

1920 if node_changed: 

1921 try: 

1922 await self._event_dispatcher.dispatch_async( 

1923 AsyncAfterSlotsCacheRefreshEvent() 

1924 ) 

1925 except Exception as exc: 

1926 # Don't shadow the method parameter ``e``: ``except as`` binds 

1927 # the listener exception in the function scope and ``del``s 

1928 # the name on block exit (PEP 3134), which would also wipe 

1929 # out the original AskError/MovedError parameter. 

1930 logger.exception( 

1931 "listener raised during slots-cache refresh: %s: %s", 

1932 type(exc).__name__, 

1933 exc, 

1934 ) 

1935 

1936 def get_node_from_slot( 

1937 self, 

1938 slot: int, 

1939 read_from_replicas: bool = False, 

1940 load_balancing_strategy=None, 

1941 ) -> "ClusterNode": 

1942 if read_from_replicas is True and load_balancing_strategy is None: 

1943 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN 

1944 

1945 try: 

1946 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: 

1947 # get the server index using the strategy defined in load_balancing_strategy 

1948 primary_name = self.slots_cache[slot][0].name 

1949 node_idx = self.read_load_balancer.get_server_index( 

1950 primary_name, len(self.slots_cache[slot]), load_balancing_strategy 

1951 ) 

1952 return self.slots_cache[slot][node_idx] 

1953 return self.slots_cache[slot][0] 

1954 except (IndexError, TypeError): 

1955 raise SlotNotCoveredError( 

1956 f'Slot "{slot}" not covered by the cluster. ' 

1957 f'"require_full_coverage={self.require_full_coverage}"' 

1958 ) 

1959 

1960 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: 

1961 return [ 

1962 node 

1963 for node in self.nodes_cache.values() 

1964 if node.server_type == server_type 

1965 ] 

1966 

1967 async def initialize( 

1968 self, 

1969 additional_startup_nodes_info: Optional[List[Tuple[str, int]]] = None, 

1970 last_failed_node_name: Optional[str] = None, 

1971 ) -> None: 

1972 self.read_load_balancer.reset() 

1973 tmp_nodes_cache: Dict[str, "ClusterNode"] = {} 

1974 tmp_slots: Dict[int, List["ClusterNode"]] = {} 

1975 disagreements = [] 

1976 startup_nodes_reachable = False 

1977 fully_covered = False 

1978 exception = None 

1979 epoch = self._epoch 

1980 if additional_startup_nodes_info is None: 

1981 additional_startup_nodes_info = [] 

1982 

1983 async with self._initialize_lock: 

1984 if self._epoch != epoch: 

1985 # another initialize call has already reinitialized the 

1986 # nodes since we started waiting for the lock; 

1987 # we don't need to do it again. 

1988 return 

1989 

1990 # Copy to a list to prevent RuntimeError if self.startup_nodes 

1991 # is modified during iteration, then shuffle the iteration order. 

1992 startup_nodes = list(self.startup_nodes.values()) 

1993 deferred_failed_nodes = [] 

1994 if last_failed_node_name is not None: 

1995 for index, node in enumerate(startup_nodes): 

1996 if node.name == last_failed_node_name: 

1997 deferred_failed_nodes.append(startup_nodes.pop(index)) 

1998 break 

1999 if len(startup_nodes) > 1: 

2000 # Vary which startup node is queried first so clients do not 

2001 # all reinitialize through the same node. 

2002 random.shuffle(startup_nodes) 

2003 additional_startup_nodes = [ 

2004 ClusterNode(host, port, **self.connection_kwargs) 

2005 for host, port in additional_startup_nodes_info 

2006 ] 

2007 if last_failed_node_name is not None: 

2008 for index, node in enumerate(additional_startup_nodes): 

2009 if node.name == last_failed_node_name: 

2010 if not deferred_failed_nodes: 

2011 deferred_failed_nodes.append(node) 

2012 additional_startup_nodes.pop(index) 

2013 break 

2014 for startup_node in chain( 

2015 startup_nodes, 

2016 additional_startup_nodes, 

2017 deferred_failed_nodes, 

2018 ): 

2019 try: 

2020 # Make sure cluster mode is enabled on this node 

2021 try: 

2022 self._event_dispatcher.dispatch( 

2023 AfterAsyncClusterInstantiationEvent( 

2024 self.nodes_cache, 

2025 self.connection_kwargs.get("credential_provider", None), 

2026 ) 

2027 ) 

2028 cluster_slots = await startup_node.execute_command( 

2029 "CLUSTER SLOTS" 

2030 ) 

2031 except ResponseError: 

2032 raise RedisClusterException( 

2033 "Cluster mode is not enabled on this node" 

2034 ) 

2035 startup_nodes_reachable = True 

2036 except Exception as e: 

2037 # Try the next startup node. 

2038 # The exception is saved and raised only if we have no more nodes. 

2039 exception = e 

2040 continue 

2041 

2042 # CLUSTER SLOTS command results in the following output: 

2043 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] 

2044 # where each node contains the following list: [IP, port, node_id] 

2045 # Therefore, cluster_slots[0][2][0] will be the IP address of the 

2046 # primary node of the first slot section. 

2047 # If there's only one server in the cluster, its ``host`` is '' 

2048 # Fix it to the host in startup_nodes 

2049 if ( 

2050 len(cluster_slots) == 1 

2051 and not cluster_slots[0][2][0] 

2052 and len(self.startup_nodes) == 1 

2053 ): 

2054 cluster_slots[0][2][0] = startup_node.host 

2055 

2056 for slot in cluster_slots: 

2057 for i in range(2, len(slot)): 

2058 slot[i] = [str_if_bytes(val) for val in slot[i]] 

2059 primary_node = slot[2] 

2060 host = primary_node[0] 

2061 if host == "": 

2062 host = startup_node.host 

2063 port = int(primary_node[1]) 

2064 host, port = self.remap_host_port(host, port) 

2065 

2066 nodes_for_slot = [] 

2067 

2068 target_node = tmp_nodes_cache.get(get_node_name(host, port)) 

2069 if not target_node: 

2070 target_node = ClusterNode( 

2071 host, port, PRIMARY, **self.connection_kwargs 

2072 ) 

2073 # add this node to the nodes cache 

2074 tmp_nodes_cache[target_node.name] = target_node 

2075 nodes_for_slot.append(target_node) 

2076 

2077 replica_nodes = slot[3:] 

2078 for replica_node in replica_nodes: 

2079 host = replica_node[0] 

2080 port = replica_node[1] 

2081 host, port = self.remap_host_port(host, port) 

2082 

2083 target_replica_node = tmp_nodes_cache.get( 

2084 get_node_name(host, port) 

2085 ) 

2086 if not target_replica_node: 

2087 target_replica_node = ClusterNode( 

2088 host, port, REPLICA, **self.connection_kwargs 

2089 ) 

2090 # add this node to the nodes cache 

2091 tmp_nodes_cache[target_replica_node.name] = target_replica_node 

2092 nodes_for_slot.append(target_replica_node) 

2093 

2094 for i in range(int(slot[0]), int(slot[1]) + 1): 

2095 if i not in tmp_slots: 

2096 tmp_slots[i] = nodes_for_slot 

2097 else: 

2098 # Validate that 2 nodes want to use the same slot cache 

2099 # setup 

2100 tmp_slot = tmp_slots[i][0] 

2101 if tmp_slot.name != target_node.name: 

2102 disagreements.append( 

2103 f"{tmp_slot.name} vs {target_node.name} on slot: {i}" 

2104 ) 

2105 

2106 if len(disagreements) > 5: 

2107 raise RedisClusterException( 

2108 f"startup_nodes could not agree on a valid " 

2109 f"slots cache: {', '.join(disagreements)}" 

2110 ) 

2111 

2112 # Validate if all slots are covered or if we should try next startup node 

2113 fully_covered = True 

2114 for i in range(REDIS_CLUSTER_HASH_SLOTS): 

2115 if i not in tmp_slots: 

2116 fully_covered = False 

2117 break 

2118 if fully_covered: 

2119 break 

2120 

2121 if not startup_nodes_reachable: 

2122 raise RedisClusterException( 

2123 f"Redis Cluster cannot be connected. Please provide at least " 

2124 f"one reachable node: {str(exception)}" 

2125 ) from exception 

2126 

2127 # Check if the slots are not fully covered 

2128 if not fully_covered and self.require_full_coverage: 

2129 # Despite the requirement that the slots be covered, there 

2130 # isn't a full coverage 

2131 raise RedisClusterException( 

2132 f"All slots are not covered after query all startup_nodes. " 

2133 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " 

2134 f"covered..." 

2135 ) 

2136 

2137 # Set the tmp variables to the real variables 

2138 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) 

2139 # tmp_slots was built from CLUSTER SLOTS responses and can contain 

2140 # newly-created ClusterNode objects for nodes we already know about. 

2141 # Rebuild the slots cache with the preserved nodes_cache instances 

2142 # so existing per-node connection pools stay in use after refresh. 

2143 # Keep the shared node-list-per-slot-range shape from tmp_slots to 

2144 # avoid allocating a separate list for every slot. 

2145 node_lists_by_id: Dict[int, List["ClusterNode"]] = {} 

2146 new_slots_cache: Dict[int, List["ClusterNode"]] = {} 

2147 for slot, nodes in tmp_slots.items(): 

2148 node_list_id = id(nodes) 

2149 slot_nodes = node_lists_by_id.get(node_list_id) 

2150 if slot_nodes is None: 

2151 slot_nodes = [self.nodes_cache[node.name] for node in nodes] 

2152 node_lists_by_id[node_list_id] = slot_nodes 

2153 new_slots_cache[slot] = slot_nodes 

2154 self.slots_cache = new_slots_cache 

2155 

2156 if self._dynamic_startup_nodes: 

2157 # Populate the startup nodes with all discovered nodes 

2158 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) 

2159 

2160 # Set the default node 

2161 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] 

2162 self._epoch += 1 

2163 # Dispatch so listeners (e.g. ClusterPubSub) can reconcile per-node 

2164 # state after slot ownership may have changed. A listener must not 

2165 # break slots-cache refresh; log and continue so a single buggy 

2166 # listener cannot starve the rest. 

2167 try: 

2168 await self._event_dispatcher.dispatch_async( 

2169 AsyncAfterSlotsCacheRefreshEvent() 

2170 ) 

2171 except Exception as e: 

2172 logger.exception( 

2173 "listener raised during slots-cache refresh: %s: %s", 

2174 type(e).__name__, 

2175 e, 

2176 ) 

2177 

2178 async def aclose(self, attr: str = "nodes_cache") -> None: 

2179 self.default_node = None 

2180 await asyncio.gather( 

2181 *( 

2182 asyncio.create_task(node.disconnect()) 

2183 for node in getattr(self, attr).values() 

2184 ) 

2185 ) 

2186 

2187 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: 

2188 """ 

2189 Remap the host and port returned from the cluster to a different 

2190 internal value. Useful if the client is not connecting directly 

2191 to the cluster. 

2192 """ 

2193 if self.address_remap: 

2194 return self.address_remap((host, port)) 

2195 return host, port 

2196 

2197 

class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    All real work is delegated to an execution strategy:
    :class:`TransactionStrategy` when ``transaction`` is truthy, otherwise
    :class:`PipelineStrategy`.

    :param client:
        | Existing :class:`~.RedisCluster` client
    :param transaction:
        | Select the transactional strategy when truthy
    """

    __slots__ = (
        "cluster_client",
        "_transaction",
        "_execution_strategy",
    )

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        strategy_cls = TransactionStrategy if self._transaction else PipelineStrategy
        self._execution_strategy: ExecutionStrategy = strategy_cls(self)

    @property
    def nodes_manager(self) -> "NodesManager":
        """The :class:`NodesManager` of the wrapped cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Register ``callback`` for ``command`` on the wrapped cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the execution strategy and return this pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        pipeline = await self.initialize()
        return pipeline

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        # Leaving the ``async with`` block always resets the pipeline.
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        # ``await pipe`` is shorthand for ``await pipe.initialize()``.
        initializing = self.initialize()
        return initializing.__await__()

    def __bool__(self) -> bool:
        """Always truthy; without this, ``__len__`` would make an empty pipeline falsy."""
        return True

    def __len__(self) -> int:
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Queue a raw command on the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run every queued command and return their results in queue order.

        Retries follow :attr:`retry`; once retries are exhausted the error is
        raised. The pipeline is reset afterwards, whether or not it succeeded.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        strategy = self._execution_strategy
        try:
            return await strategy.execute(raise_on_error, allow_redirections)
        finally:
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        """Queue ``command`` once per hash slot, each time with that slot's keys."""
        keys_by_slot = self.cluster_client._partition_keys_by_slot(keys)
        for keys_in_slot in keys_by_slot.values():
            self.execute_command(command, *keys_in_slot)
        return self

    async def reset(self):
        """Return the pipeline to an empty state via its execution strategy."""
        await self._execution_strategy.reset()

    def multi(self):
        """
        Open the transactional block after any WATCH commands; it is
        closed by `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """Forward DISCARD handling to the execution strategy."""
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """WATCH the keys in ``names``."""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Drop every key watched so far."""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Forward UNLINK of ``names`` to the execution strategy."""
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot for the pairs in ``mapping`` (not atomic)."""
        return self._execution_strategy.mset_nonatomic(mapping)

2365 

2366 

# Install a replacement attribute on ClusterPipeline for every command listed
# in PIPELINE_BLOCKED_COMMANDS. The attribute name is the command in
# snake_case (spaces -> underscores, lower-cased); its value is whatever
# block_pipeline_command(name) returns (presumably a method that refuses to
# run in a pipeline; confirm against block_pipeline_command).
# mset_nonatomic is left alone because ClusterPipeline implements it itself.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))

2373 

2374 

class PipelineCommand:
    """One queued pipeline command, its position, and (after execution) its outcome."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        """
        :param position: index of the command in its pipeline
        :param args: raw command arguments (command name first)
        :param kwargs: options that travel with the command
        """
        self.args = args
        self.kwargs = kwargs
        self.position = position
        # Starts as None; the pipeline stores the reply or the raised
        # Exception here when it runs.
        self.result: Union[Any, Exception] = None
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)

2385 

2386 

class ExecutionStrategy(ABC):
    """
    Interface of the objects that run a :class:`ClusterPipeline`.

    ClusterPipeline forwards its public API to an instance of this class.
    Each abstract method mirrors the ClusterPipeline method of the same name.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Get the strategy ready for use.

        Mirrors ClusterPipeline.initialize().
        """
        ...

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Queue a raw command.

        Mirrors ClusterPipeline.execute_command().
        """
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands, retrying as configured by :attr:`retry`
        before giving up with an exception.

        Mirrors ClusterPipeline.execute().
        """
        ...

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue MSET commands grouped by the hash slot of each key in ``mapping``.

        Mirrors ClusterPipeline.mset_nonatomic().
        """
        ...

    @abstractmethod
    async def reset(self):
        """
        Return the strategy to its empty state.

        Mirrors ClusterPipeline.reset().
        """
        ...

    @abstractmethod
    def multi(self):
        """
        Begin a transactional block.

        Mirrors ClusterPipeline.multi().
        """
        ...

    @abstractmethod
    async def watch(self, *names):
        """
        WATCH the keys in ``names``.

        Mirrors ClusterPipeline.watch().
        """
        ...

    @abstractmethod
    async def unwatch(self):
        """
        Stop watching every previously watched key.

        Mirrors ClusterPipeline.unwatch().
        """
        ...

    @abstractmethod
    async def discard(self):
        """
        Handle DISCARD.

        Mirrors ClusterPipeline.discard().
        """
        ...

    @abstractmethod
    async def unlink(self, *names):
        """
        UNLINK the keys in ``names``.

        Mirrors ClusterPipeline.unlink().
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of queued commands."""
        ...

2485 

2486 

class AbstractStrategy(ExecutionStrategy):
    """
    Shared base for pipeline execution strategies.

    Holds the pipeline it serves and the queue of :class:`PipelineCommand`
    objects built by :meth:`execute_command`. The execution-specific
    operations stay abstract.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        """:param pipe: the pipeline this strategy executes for"""
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the cluster client if it still needs it, clear the queue, return the pipeline."""
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Queue a command at the next position and return the pipeline for chaining."""
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled.

        Rewrites ``exception.args[0]`` into a message naming the command
        number and a truncated rendering of ``command``. Any remaining args
        are preserved.

        :param exception: the exception to annotate (mutated in place)
        :param number: command number used in the message
        :param command: the command's raw arguments
        """
        cmd = " ".join(map(safe_str, command))
        # Exceptions raised without arguments have an empty ``args``. Fall back
        # to the type name rather than failing with IndexError, which would
        # hide the original error.
        detail = exception.args[0] if exception.args else type(exception).__name__
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {detail}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        """Number of commands currently queued."""
        return len(self._command_queue)

2555 

2556 

2557class PipelineStrategy(AbstractStrategy): 

def __init__(self, pipe: ClusterPipeline) -> None:
    """
    Create the non-transactional strategy for ``pipe``.

    All state (the pipe and an empty command queue) is set up by the
    parent ``__init__``.
    """
    super().__init__(pipe)

2560 

def mset_nonatomic(
    self, mapping: Mapping[AnyKeyT, EncodableT]
) -> "ClusterPipeline":
    """
    Queue one ``MSET`` per hash slot covering the pairs in ``mapping``.

    Keys are encoded with the cluster client's encoder and grouped by
    ``key_slot``. Groups are emitted in first-seen slot order, and pairs keep
    their mapping order within each group. The writes are not atomic across
    slots.

    :return: the pipeline, for chaining
    """
    encoder = self._pipe.cluster_client.encoder

    pairs_by_slot = {}
    for key, value in mapping.items():
        bucket = pairs_by_slot.setdefault(key_slot(encoder.encode(key)), [])
        bucket.extend((key, value))

    for flat_pairs in pairs_by_slot.values():
        self.execute_command("MSET", *flat_pairs)

    return self._pipe

2575 

async def execute(
    self, raise_on_error: bool = True, allow_redirections: bool = True
) -> List[Any]:
    """
    Run the queued commands, retrying on retryable cluster errors.

    An empty queue returns ``[]`` immediately. Otherwise each attempt does
    the following:

    * Initialize the cluster client if needed.
    * Delegate to :meth:`_execute`.
    * On an error in ``RedisCluster.ERRORS_ALLOW_RETRY``, close the client
      and sleep 0.25 s, while ``client.retry.get_retries()`` attempts remain.

    The queue is reset afterwards whatever the outcome.

    :return: one result per queued command, in queue order
    """
    if not self._command_queue:
        return []

    try:
        client = self._pipe.cluster_client
        attempts_left = client.retry.get_retries()
        while True:
            try:
                if client._initialize:
                    await client.initialize()
                return await self._execute(
                    client,
                    self._command_queue,
                    raise_on_error=raise_on_error,
                    allow_redirections=allow_redirections,
                )
            except RedisCluster.ERRORS_ALLOW_RETRY:
                if attempts_left <= 0:
                    # Out of attempts: surface the error to the caller.
                    raise
                # Try again with the new cluster setup. All other errors
                # propagate untouched.
                attempts_left -= 1
                await client.aclose()
                await asyncio.sleep(0.25)
    finally:
        await self.reset()

2607 

2608 async def _execute( 

2609 self, 

2610 client: "RedisCluster", 

2611 stack: List["PipelineCommand"], 

2612 raise_on_error: bool = True, 

2613 allow_redirections: bool = True, 

2614 ) -> List[Any]: 

2615 todo = [ 

2616 cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception) 

2617 ] 

2618 

2619 nodes = {} 

2620 for cmd in todo: 

2621 passed_targets = cmd.kwargs.pop("target_nodes", None) 

2622 command_policies = await client._policy_resolver.resolve( 

2623 cmd.args[0].lower() 

2624 ) 

2625 

2626 if passed_targets and not client._is_node_flag(passed_targets): 

2627 target_nodes = client._parse_target_nodes(passed_targets) 

2628 

2629 if not command_policies: 

2630 command_policies = CommandPolicies() 

2631 else: 

2632 if not command_policies: 

2633 command_flag = client.command_flags.get(cmd.args[0]) 

2634 if not command_flag: 

2635 # Fallback to default policy 

2636 if not client.get_default_node(): 

2637 slot = None 

2638 else: 

2639 slot = await client._determine_slot(*cmd.args) 

2640 if slot is None: 

2641 command_policies = CommandPolicies() 

2642 else: 

2643 command_policies = CommandPolicies( 

2644 request_policy=RequestPolicy.DEFAULT_KEYED, 

2645 response_policy=ResponsePolicy.DEFAULT_KEYED, 

2646 ) 

2647 else: 

2648 if command_flag in client._command_flags_mapping: 

2649 command_policies = CommandPolicies( 

2650 request_policy=client._command_flags_mapping[ 

2651 command_flag 

2652 ] 

2653 ) 

2654 else: 

2655 command_policies = CommandPolicies() 

2656 

2657 target_nodes = await client._determine_nodes( 

2658 *cmd.args, 

2659 request_policy=command_policies.request_policy, 

2660 node_flag=passed_targets, 

2661 ) 

2662 if not target_nodes: 

2663 raise RedisClusterException( 

2664 f"No targets were found to execute {cmd.args} command on" 

2665 ) 

2666 cmd.command_policies = command_policies 

2667 if len(target_nodes) > 1: 

2668 raise RedisClusterException(f"Too many targets for command {cmd.args}") 

2669 node = target_nodes[0] 

2670 if node.name not in nodes: 

2671 nodes[node.name] = (node, []) 

2672 nodes[node.name][1].append(cmd) 

2673 

2674 # Start timing for observability 

2675 start_time = time.monotonic() 

2676 

2677 errors = await asyncio.gather( 

2678 *( 

2679 asyncio.create_task(node[0].execute_pipeline(node[1])) 

2680 for node in nodes.values() 

2681 ) 

2682 ) 

2683 

2684 # Record operation duration for each node 

2685 for node_name, (node, commands) in nodes.items(): 

2686 # Find the first error in this node's commands, if any 

2687 node_error = None 

2688 for cmd in commands: 

2689 if isinstance(cmd.result, Exception): 

2690 node_error = cmd.result 

2691 break 

2692 

2693 db = node.connection_kwargs.get("db", 0) 

2694 await record_operation_duration( 

2695 command_name="PIPELINE", 

2696 duration_seconds=time.monotonic() - start_time, 

2697 server_address=node.host, 

2698 server_port=node.port, 

2699 db_namespace=str(db) if db is not None else None, 

2700 error=node_error, 

2701 ) 

2702 

2703 if any(errors): 

2704 if allow_redirections: 

2705 # send each errored command individually 

2706 for cmd in todo: 

2707 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): 

2708 try: 

2709 cmd.result = client._policies_callback_mapping[ 

2710 cmd.command_policies.response_policy 

2711 ](await client.execute_command(*cmd.args, **cmd.kwargs)) 

2712 except Exception as e: 

2713 cmd.result = e 

2714 

2715 if raise_on_error: 

2716 for cmd in todo: 

2717 result = cmd.result 

2718 if isinstance(result, Exception): 

2719 command = " ".join(map(safe_str, cmd.args)) 

2720 msg = ( 

2721 f"Command # {cmd.position + 1} " 

2722 f"({truncate_text(command)}) " 

2723 f"of pipeline caused error: {result.args}" 

2724 ) 

2725 result.args = (msg,) + result.args[1:] 

2726 raise result 

2727 

2728 default_cluster_node = client.get_default_node() 

2729 

2730 # Check whether the default node was used. In some cases, 

2731 # 'client.get_default_node()' may return None. The check below 

2732 # prevents a potential AttributeError. 

2733 if default_cluster_node is not None: 

2734 default_node = nodes.get(default_cluster_node.name) 

2735 if default_node is not None: 

2736 # This pipeline execution used the default node, check if we need 

2737 # to replace it. 

2738 # Note: when the error is raised we'll reset the default node in the 

2739 # caller function. 

2740 for cmd in default_node[1]: 

2741 # Check if it has a command that failed with a relevant 

2742 # exception 

2743 if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY: 

2744 client.replace_default_node() 

2745 break 

2746 

2747 return [cmd.result for cmd in stack] 

2748 

2749 async def reset(self): 

2750 """ 

2751 Reset back to empty pipeline. 

2752 """ 

2753 self._command_queue = [] 

2754 

2755 def multi(self): 

2756 raise RedisClusterException( 

2757 "method multi() is not supported outside of transactional context" 

2758 ) 

2759 

2760 async def watch(self, *names): 

2761 raise RedisClusterException( 

2762 "method watch() is not supported outside of transactional context" 

2763 ) 

2764 

2765 async def unwatch(self): 

2766 raise RedisClusterException( 

2767 "method unwatch() is not supported outside of transactional context" 

2768 ) 

2769 

2770 async def discard(self): 

2771 raise RedisClusterException( 

2772 "method discard() is not supported outside of transactional context" 

2773 ) 

2774 

2775 async def unlink(self, *names): 

2776 if len(names) != 1: 

2777 raise RedisClusterException( 

2778 "unlinking multiple keys is not implemented in pipeline command" 

2779 ) 

2780 

2781 return self.execute_command("UNLINK", names[0]) 

2782 

2783 

class TransactionStrategy(AbstractStrategy):
    """
    Pipeline execution strategy that runs queued commands as a MULTI/EXEC
    transaction, optionally preceded by WATCH, on a Redis Cluster.

    Every key involved must map to a single hash slot; all traffic goes over
    one connection to the node owning that slot.
    """

    # Keyless commands that do not pin the transaction to a slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands sent right away (instead of being queued) outside of MULTI.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which ``_watching`` is cleared.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        self._pipeline_slots: Set[int] = set()
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Work on a copy of the client's Retry so the supported-error update
        # below is scoped to this strategy.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not been
        altered as long as the watch commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to the
        node that owns that slot and reuse it.

        :raises RedisClusterException: If no keyed command has pinned a slot yet.
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            next(iter(self._pipeline_slots)), False
        )
        self._transaction_node = node

        # Reuse an existing connection so WATCH state on it is preserved.
        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """
        Route a command either to immediate execution or to the queue.

        ``_execute_command`` is run to completion on a fresh event loop in a
        helper thread, and the calling thread blocks on ``join()`` meanwhile.
        For immediately-executed commands (WATCH/UNWATCH, or any command while
        watching outside MULTI) the value returned is the *un-awaited*
        coroutine from ``_immediate_execute_command``, which the caller must
        await (as ``watch``/``unwatch`` do). For queued commands it is the
        return value of ``AbstractStrategy.execute_command``.

        :raises Exception: Any exception raised by ``_execute_command``.
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error is not None:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Resolve the command's slot, enforce single-slot use, and either return
        a coroutine for immediate execution or queue the command.

        :raises CrossSlotTransactionError: When watching keys of different slots.
        :raises RedisClusterException: When an immediate command has no slot.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # Deliberately not awaited: the coroutine is handed back to the
            # caller (see execute_command).
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        """Mark the strategy as watching; WATCH is illegal after MULTI."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        """Send one command right away, retrying per ``self._retry``."""
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
            with_failure_count=True,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        """
        Send one command over the transaction connection and record its
        duration. On failure the connection is attached to the exception as
        ``e.connection`` for ``_reinitialize_on_error``.
        """
        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        # Start timing for observability
        start_time = time.monotonic()

        try:
            response = await self._send_command_parse_response(
                connection, redis_node, args[0], *args, **options
            )

            await record_operation_duration(
                command_name=args[0],
                duration_seconds=time.monotonic() - start_time,
                server_address=connection.host,
                server_port=connection.port,
                db_namespace=str(connection.db),
            )

            return response
        except Exception as e:
            e.connection = connection
            await record_operation_duration(
                command_name=args[0],
                duration_seconds=time.monotonic() - start_time,
                server_address=connection.host,
                server_port=connection.port,
                db_namespace=str(connection.db),
                error=e,
            )
            raise

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error, failure_count):
        """
        Retry hook: record the error, drop the transaction connection on
        redirect/connection errors and periodically refresh the topology.

        :raises WatchError: If a slot redirect hits while executing a watched
            transaction (the WATCH guarantees no longer hold).
        """
        if hasattr(error, "connection"):
            await record_error_count(
                server_address=error.connection.host,
                server_port=error.connection.port,
                network_peer_address=error.connection.host,
                network_peer_port=error.connection.port,
                error_type=error,
                retry_attempts=failure_count,
                is_internal=True,
            )

        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            if self._transaction_connection and self._transaction_node:
                # Disconnect and release back to pool
                await self._transaction_connection.disconnect()
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None

            cluster_client = self._pipe.cluster_client
            cluster_client.reinitialize_counter += 1
            if (
                cluster_client.reinitialize_steps
                and cluster_client.reinitialize_counter
                % cluster_client.reinitialize_steps
                == 0
            ):
                await cluster_client.nodes_manager.initialize()
                # Reset the counter on the client that owns it.
                cluster_client.reinitialize_counter = 0
            else:
                if isinstance(error, AskError):
                    await cluster_client.nodes_manager.move_slot(error)

        self._executing = False

    async def _raise_first_error(self, responses, stack, start_time):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)

                await record_operation_duration(
                    command_name="TRANSACTION",
                    duration_seconds=time.monotonic() - start_time,
                    server_address=self._transaction_connection.host,
                    server_port=self._transaction_connection.port,
                    db_namespace=str(self._transaction_connection.db),
                    error=r,
                )

                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the queued transaction.

        :param raise_on_error: Raise the first command error in the result.
        :param allow_redirections: Accepted for interface compatibility; unused here.
        :return: One result per queued command ([] when there is nothing to run).
        """
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Run ``_execute_transaction`` under ``self._retry``."""
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            lambda error, failure_count: self._reinitialize_on_error(
                error, failure_count
            ),
            with_failure_count=True,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send the queued commands wrapped in MULTI/EXEC and parse the replies.

        :param stack: The queued commands.
        :param raise_on_error: Raise the first error found in the EXEC reply
            instead of returning it in the result list.
        :return: One result per queued command, after response callbacks.
        :raises CrossSlotTransactionError: If the keys span more than one slot.
        :raises WatchError: If EXEC returned nil (a watched key changed).
        :raises InvalidPipelineStack: If the EXEC reply length is unexpected.
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands flagged with EMPTY_RESPONSE are never sent to the server.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)

        # Exceptions raised while reading the MULTI / queued-command replies.
        errors: List[Exception] = []
        # (position, value) for EMPTY_RESPONSE commands that were not sent;
        # their canned value is spliced into the EXEC reply. Kept separate
        # from ``errors`` so ``errors[0]`` is always a raisable exception.
        empty_responses: List[Tuple[int, Any]] = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            cluster_error.connection = connection
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                empty_responses.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # Surface the queuing error that caused the abort, if any.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        if errors:
            # A queuing error normally makes EXEC abort (handled above); never
            # silently drop one if the server executed the transaction anyway.
            raise errors[0]

        # put the canned EMPTY_RESPONSE values into the response
        for i, value in empty_responses:
            response.insert(i, value)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if requested
        if raise_on_error:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data

    async def reset(self):
        """
        Clear the queue, UNWATCH if needed, return the transaction connection
        to its node and reset all transaction state.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                await self._transaction_node.disconnect_if_needed(
                    self._transaction_connection
                )
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection and self._transaction_node:
                    await self._transaction_connection.disconnect()
                    self._transaction_node.release(self._transaction_connection)
                    self._transaction_connection = None

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        """Start an explicit transaction; must precede any queued command."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        """WATCH the given keys; not allowed after MULTI."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Send UNWATCH if currently watching; otherwise return True."""
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        """Abandon the transaction by resetting the strategy."""
        await self.reset()

    async def unlink(self, *names):
        """Queue an UNLINK for the given keys."""
        return self.execute_command("UNLINK", *names)

3224 

3225 

class _ClusterNodePoolAdapter(ConnectionPoolInterface):
    """Thin adapter exposing the :class:`ConnectionPoolInterface` that
    :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own
    connection pool.

    Connections are acquired from the node via
    :meth:`ClusterNode.acquire_connection` and returned via
    :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already
    disconnects the connection *before* calling :meth:`release`, so the
    connection is returned to the node's free-queue in a disconnected
    state — guaranteeing that a subscribed socket is never silently
    reused for regular commands.

    Methods that do not apply to this adapter (the underlying node's
    lifecycle is managed by the cluster, not by individual PubSub
    instances) are implemented as no-ops so the adapter remains a valid
    :class:`ConnectionPoolInterface`.
    """

    def __init__(self, node: "ClusterNode") -> None:
        self._node = node
        self.connection_kwargs = node.connection_kwargs

    # -- methods used by PubSub ------------------------------------------------

    def get_encoder(self) -> Encoder:
        return self._node.get_encoder()

    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> AbstractConnection:
        """
        Acquire a connection from the node and make sure it is connected.

        On any failure (including cancellation) the connection is
        disconnected and handed back to the node before re-raising.
        Releasing happens in ``finally`` so that a failing ``disconnect()``
        cannot leak the connection out of the node's pool.
        """
        connection = self._node.acquire_connection()
        try:
            await connection.connect()
        except BaseException:
            # connect() may fail mid-handshake (e.g. after the TCP socket
            # is established but before AUTH/HELLO completes) leaving the
            # connection in a partially-connected state. Disconnect before
            # returning it to the node's free queue so it is not reused.
            try:
                await connection.disconnect()
            finally:
                self._node.release(connection)
            raise
        return connection

    async def release(self, connection: AbstractConnection) -> None:
        """
        Return ``connection`` to the node, disconnecting it first if the node
        reports that is needed. The release happens even if that disconnect
        step raises, so the connection is never leaked.
        """
        # PubSub.aclose() disconnects the connection before calling
        # release(), so it is safe to put it back in the node's free
        # queue – it will reconnect lazily on next use.
        try:
            await self._node.disconnect_if_needed(connection)
        finally:
            self._node.release(connection)

    # -- no-op stubs for the rest of ConnectionPoolInterface -------------------
    # The node's connections are shared with regular cluster traffic and its
    # lifecycle is managed by RedisCluster / NodesManager, so the adapter must
    # not reset, disconnect, retry-configure or re-auth them on behalf of a
    # single PubSub instance.

    def get_protocol(self):
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        pass

    async def disconnect(self, inuse_connections: bool = True) -> None:
        pass

    async def aclose(self) -> None:
        pass

    def set_retry(self, retry: "Retry") -> None:
        pass

    async def re_auth_callback(self, token: TokenInterface) -> None:
        pass

    def get_connection_count(self) -> List[Tuple[int, dict]]:
        return []

3303 

3304 

def _unregister_slots_cache_listener(
    dispatcher_ref: "weakref.ref[EventDispatcher]",
    listener: AsyncEventListenerInterface,
    event_type: Type[object],
) -> None:
    """
    Finalizer callback: remove ``listener`` for ``event_type`` from the
    dispatcher referenced by ``dispatcher_ref``, if it is still alive.
    """
    # Defined at module level (not as a ClusterPubSub method) so that handing
    # it to weakref.finalize does not hold a strong reference to the pubsub.
    target_dispatcher = dispatcher_ref()
    if target_dispatcher is None:
        return
    target_dispatcher.unregister_listeners({event_type: [listener]})

3316 

3317 

class ClusterPubSubSlotsCacheListener(AsyncEventListenerInterface):
    """
    Relay AsyncAfterSlotsCacheRefreshEvent notifications to a ClusterPubSub.

    Only a weak reference to the pubsub is kept, so this listener never
    extends the pubsub's lifetime. Removing the dispatcher's strong reference
    to the listener is the job of a ``weakref.finalize`` that
    ``ClusterPubSub.__init__`` attaches to the owning pubsub.
    """

    def __init__(self, pubsub: "ClusterPubSub") -> None:
        self._pubsub_ref: "weakref.ref[ClusterPubSub]" = weakref.ref(pubsub)

    async def listen(self, event: object) -> None:
        target = self._pubsub_ref()
        # Between the pubsub being collected and its finalizer unregistering
        # this listener there is nothing to forward to.
        if target is not None:
            try:
                await target.on_slots_changed()
            except Exception as exc:
                # Never let one pubsub's failure abort the slots-cache refresh
                # or starve the other listeners; log it and carry on.
                logger.exception(
                    "pubsub %r raised during slots-cache change: %s: %s",
                    target,
                    type(exc).__name__,
                    exc,
                )

3348 ) 

3349 

3350 

3351class ClusterPubSub(PubSub): 

3352 """ 

3353 Async cluster implementation for pub/sub. 

3354 

3355 IMPORTANT: before using ClusterPubSub, read about the known limitations 

3356 with pubsub in Cluster mode and learn how to workaround them: 

3357 https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations 

3358 """ 

3359 

    def __init__(
        self,
        redis_cluster: "RedisCluster",
        node: Optional["ClusterNode"] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        **kwargs: Any,
    ) -> None:
        """
        When a pubsub instance is created without specifying a node, a single
        node will be transparently chosen for the pubsub connection on the
        first command execution. The node will be determined by:
         1. Hashing the channel name in the request to find its keyslot
         2. Selecting a node that handles the keyslot: If read_from_replicas is
            set to true or load_balancing_strategy is set, a replica can be selected.

        Besides initializing the base PubSub, this registers a listener for
        AsyncAfterSlotsCacheRefreshEvent on the cluster's nodes-manager
        dispatcher and attaches a weakref.finalize that unregisters it when
        this instance is garbage-collected.

        :param redis_cluster: RedisCluster instance
        :param node: ClusterNode to connect to
        :param host: Host of the node to connect to
        :param port: Port of the node to connect to
        :param push_handler_func: Optional push handler function
        :param event_dispatcher: Optional event dispatcher; a new
            EventDispatcher is created when omitted
        :param kwargs: Additional keyword arguments forwarded to PubSub
        :raises DataError: If only one of host / port is given
            (raised by set_pubsub_node).
        """
        # Pre-set so the attribute exists before set_pubsub_node assigns it.
        self.node = None
        self.set_pubsub_node(redis_cluster, node, host, port)

        # Borrow the node's own connection pool via an adapter rather than
        # creating a second, detached ConnectionPool for pubsub.
        if self.node is not None:
            connection_pool = _ClusterNodePoolAdapter(self.node)
        else:
            connection_pool = None

        self.cluster = redis_cluster
        self.node_pubsub_mapping: Dict[str, PubSub] = {}
        # Reverse index: shard channel (normalized) -> owning node.name. Used to
        # route sunsubscribe calls and reconcile subscriptions after slot
        # migration / failover.
        self._shard_channel_to_node: Dict[Any, str] = {}
        # Dedicated lock for shard-subscription bookkeeping. Distinct from
        # PubSub.self._lock (which serializes wire I/O on the cluster-level
        # connection used by aclose / send_command / regular subscribe) so
        # that reconciliation cannot starve those unrelated coroutines
        # during long per-channel migrations.
        self._shard_state_lock: asyncio.Lock = asyncio.Lock()
        # Background tasks created by on_slots_changed; kept to prevent GC.
        self._reconcile_tasks: Set[asyncio.Task] = set()
        # Replaces the bound method on this instance with a round-robin
        # generator object; the function itself stays reachable on the class.
        self._pubsubs_generator = self._pubsubs_generator()
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        super().__init__(
            connection_pool=connection_pool,
            encoder=redis_cluster.encoder,
            push_handler_func=push_handler_func,
            event_dispatcher=self._event_dispatcher,
            **kwargs,
        )
        # Subscribe to slots-cache change notifications so shard subscriptions
        # can be reconciled automatically after topology refreshes.
        nm_dispatcher = redis_cluster.nodes_manager._event_dispatcher
        self._slots_cache_listener = ClusterPubSubSlotsCacheListener(self)
        nm_dispatcher.register_listeners(
            {AsyncAfterSlotsCacheRefreshEvent: [self._slots_cache_listener]}
        )
        # Deterministic GC-time cleanup so short-lived pubsubs do not leak
        # listeners in the dispatcher when no slots-refresh event ever fires.
        # Only a weak reference to the dispatcher is captured.
        weakref.finalize(
            self,
            _unregister_slots_cache_listener,
            weakref.ref(nm_dispatcher),
            self._slots_cache_listener,
            AsyncAfterSlotsCacheRefreshEvent,
        )

3438 

3439 def set_pubsub_node( 

3440 self, 

3441 cluster: "RedisCluster", 

3442 node: Optional["ClusterNode"] = None, 

3443 host: Optional[str] = None, 

3444 port: Optional[int] = None, 

3445 ) -> None: 

3446 """ 

3447 The pubsub node will be set according to the passed node, host and port 

3448 When none of the node, host, or port are specified - the node is set 

3449 to None and will be determined by the keyslot of the channel in the 

3450 first command to be executed. 

3451 RedisClusterException will be thrown if the passed node does not exist 

3452 in the cluster. 

3453 If host is passed without port, or vice versa, a DataError will be 

3454 thrown. 

3455 """ 

3456 if node is not None: 

3457 # node is passed by the user 

3458 self._raise_on_invalid_node(cluster, node, node.host, node.port) 

3459 pubsub_node = node 

3460 elif host is not None and port is not None: 

3461 # host and port passed by the user 

3462 node = cluster.get_node(host=host, port=port) 

3463 self._raise_on_invalid_node(cluster, node, host, port) 

3464 pubsub_node = node 

3465 elif host is not None or port is not None: 

3466 # only one of host and port is specified 

3467 raise DataError("Specify both host and port") 

3468 else: 

3469 # nothing specified by the user 

3470 pubsub_node = None 

3471 self.node = pubsub_node 

3472 

3473 def get_pubsub_node(self) -> Optional["ClusterNode"]: 

3474 """ 

3475 Get the node that is being used as the pubsub connection. 

3476 

3477 :return: The ClusterNode being used for pubsub, or None if not yet determined 

3478 """ 

3479 return self.node 

3480 

3481 async def _resubscribe_shard_channels(self) -> None: 

3482 # A single node can own multiple slot ranges, so a batched 

3483 # ``SSUBSCRIBE`` covering every tracked channel would be rejected by 

3484 # Redis with a ``CROSSSLOT`` error. Group by hash slot and emit one 

3485 # ``SSUBSCRIBE`` per slot. 

3486 by_slot: defaultdict[int, dict] = defaultdict(dict) 

3487 for k, v in self.shard_channels.items(): 

3488 by_slot[key_slot(self.encoder.encode(k))][k] = v 

3489 for subscriptions in by_slot.values(): 

3490 await self._resubscribe(subscriptions, self.ssubscribe) 

3491 

3492 def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: 

3493 """Get or create a PubSub instance for the given node.""" 

3494 try: 

3495 return self.node_pubsub_mapping[node.name] 

3496 except KeyError: 

3497 pubsub = PubSub( 

3498 connection_pool=_ClusterNodePoolAdapter(node), 

3499 encoder=self.cluster.encoder, 

3500 push_handler_func=self.push_handler_func, 

3501 event_dispatcher=self._event_dispatcher, 

3502 ) 

3503 # Replay shard subscriptions on reconnect with slot-aware grouping 

3504 # so that channels spanning multiple slots owned by this node do 

3505 # not trigger a CROSSSLOT error. 

3506 pubsub._resubscribe_shard_channels = MethodType( 

3507 ClusterPubSub._resubscribe_shard_channels, pubsub 

3508 ) 

3509 self.node_pubsub_mapping[node.name] = pubsub 

3510 return pubsub 

3511 

3512 def _find_node_name_for_pubsub(self, pubsub: PubSub) -> Optional[str]: 

3513 for name, candidate in self.node_pubsub_mapping.items(): 

3514 if candidate is pubsub: 

3515 return name 

3516 return None 

3517 

3518 async def _sharded_message_generator( 

3519 self, timeout: float = 0.0 

3520 ) -> Tuple[Optional[PubSub], Optional[Dict[str, Any]]]: 

3521 """Generate messages from shard channels across all nodes.""" 

3522 for _ in range(len(self.node_pubsub_mapping)): 

3523 pubsub = next(self._pubsubs_generator) 

3524 # Don't pass ignore_subscribe_messages here - let get_sharded_message 

3525 # handle the filtering after processing subscription state changes 

3526 message = await pubsub.get_message( 

3527 ignore_subscribe_messages=False, timeout=timeout 

3528 ) 

3529 if message is not None: 

3530 return pubsub, message 

3531 return None, None 

3532 

3533 def _pubsubs_generator(self) -> Generator[PubSub, None, None]: 

3534 """Generator that yields PubSub instances in round-robin fashion.""" 

3535 while True: 

3536 current_nodes = list(self.node_pubsub_mapping.values()) 

3537 if not current_nodes: 

3538 return # Avoid infinite loop when no subscriptions exist 

3539 yield from current_nodes 

3540 

async def get_sharded_message(
    self,
    ignore_subscribe_messages: bool = False,
    timeout: float = 0.0,
    target_node: Optional["ClusterNode"] = None,
) -> Optional[Dict[str, Any]]:
    """
    Get a message from shard channels.

    When ``target_node`` is given, only that node's per-node pubsub is polled
    (and None is returned if no pubsub exists for it). Otherwise the per-node
    pubsubs are polled round-robin.

    A ``sunsubscribe`` confirmation updates cluster-level shard state under
    ``_shard_state_lock``. Then, if the delivering pubsub holds no more
    subscriptions, it is closed and removed from ``node_pubsub_mapping``.

    :param ignore_subscribe_messages: Whether to ignore subscribe messages
    :param timeout: Timeout for message retrieval
    :param target_node: Specific node to get message from
    :return: Message dictionary or None
    """
    pubsub: Optional[PubSub]
    message: Optional[Dict[str, Any]] = None
    if not target_node:
        pubsub, message = await self._sharded_message_generator(timeout=timeout)
    else:
        pubsub = self.node_pubsub_mapping.get(target_node.name)
        if pubsub:
            # Subscribe-message filtering happens below, after state updates.
            message = await pubsub.get_message(
                ignore_subscribe_messages=False, timeout=timeout
            )

    if message is None:
        return None

    message_type = str_if_bytes(message["type"])

    # Only sunsubscribe mutates cluster-level shard state, so the data-message
    # hot path never contends for _shard_state_lock. The (possibly blocking)
    # poll above deliberately ran outside the lock so that the background
    # reconciliation task is not stalled by long polls.
    if message_type == "sunsubscribe":
        async with self._shard_state_lock:
            channel = message["channel"]
            if channel in self.pending_unsubscribe_shard_channels:
                # User-initiated: stop tracking the channel cluster-wide.
                self.pending_unsubscribe_shard_channels.remove(channel)
                self.shard_channels.pop(channel, None)
                self._shard_channel_to_node.pop(channel, None)
            # Whether the sunsubscribe came from the user or from slot-migration
            # reconciliation, release the delivering pubsub's dedicated
            # connection once it holds nothing. It is identified by identity
            # rather than via the slot map, because after a migration the
            # channel's current owner is not the node that received it.
            if pubsub is not None and not pubsub.subscribed:
                owner_name = self._find_node_name_for_pubsub(pubsub)
                if owner_name is not None:
                    try:
                        await pubsub.aclose()
                    except Exception:
                        pass
                    self.node_pubsub_mapping.pop(owner_name, None)

    # Only control messages are suppressible; smessage data always passes.
    if message_type in ("ssubscribe", "sunsubscribe") and (
        self.ignore_subscribe_messages or ignore_subscribe_messages
    ):
        return None
    return message

3610 

async def ssubscribe(
    self, *args: ChannelT | Subscription, **kwargs: PubSubHandler
) -> None:
    """
    Subscribe to shard channels.

    Each channel is routed to the per-node pubsub of the node that owns its
    slot. If the channel is already tracked against a different node (for
    example after a slot migration), it is migrated to the current owner
    instead.

    :param args: Channel names or ``Subscription`` objects
    :param kwargs: Channel names with handlers
    """
    requested = parse_pubsub_subscriptions(args, kwargs)

    # Held for the whole loop so the reverse index, shard_channels and
    # node_pubsub_mapping are not mutated concurrently by the background
    # reconciliation task. _migrate_shard_channel does not take this lock
    # itself (asyncio.Lock is non-reentrant).
    async with self._shard_state_lock:
        for channel, callback in requested.items():
            owner = self.cluster.get_node_from_key(channel)
            if not owner:
                continue
            key = next(iter(self._normalize_keys({channel: None})))
            tracked_on = self._shard_channel_to_node.get(key)
            if tracked_on and tracked_on != owner.name:
                # Lazy re-route onto the current owner. As with
                # PubSub.ssubscribe() dict.update() semantics, the newly
                # supplied handler (even None) replaces any earlier one.
                await self._migrate_shard_channel(key, callback, tracked_on, owner)
                continue
            node_pubsub = self._get_node_pubsub(owner)
            await node_pubsub.ssubscribe(
                Subscription(channel, callback) if callback else channel
            )
            self.shard_channels.update(node_pubsub.shard_channels)
            self._shard_channel_to_node[key] = owner.name
            self.pending_unsubscribe_shard_channels.difference_update(
                self._normalize_keys({channel: None})
            )

3657 

async def sunsubscribe(self, *args: Any) -> None:
    """
    Unsubscribe from shard channels.

    Each channel is routed to the per-node pubsub that actually holds it,
    according to the reverse index. If that lookup fails, the node currently
    owning the channel's slot is used. Channels with no matching per-node
    pubsub are skipped.

    :param args: Channel names to unsubscribe from. If empty, unsubscribe from all.
    """
    targets = (
        list_or_args(args[0], args[1:]) if args else list(self.shard_channels.keys())
    )

    # The reverse index and node_pubsub_mapping must not change between the
    # lookup and the per-node sunsubscribe call below.
    async with self._shard_state_lock:
        for channel in targets:
            key = next(iter(self._normalize_keys({channel: None})))
            tracked_on = self._shard_channel_to_node.get(key)
            node_pubsub = self.node_pubsub_mapping.get(tracked_on) if tracked_on else None
            if node_pubsub is None:
                # Not in the reverse index: fall back to the slot's current owner.
                owner = self.cluster.get_node_from_key(channel)
                if not owner or owner.name not in self.node_pubsub_mapping:
                    continue
                node_pubsub = self.node_pubsub_mapping[owner.name]
            await node_pubsub.sunsubscribe(channel)
            self.pending_unsubscribe_shard_channels.update(
                node_pubsub.pending_unsubscribe_shard_channels
            )

3690 

async def reinitialize_shard_subscriptions(self) -> None:
    """
    Reconcile per-node shard subscriptions against the cluster's current
    slot ownership map. For each tracked shard channel whose owning node
    has changed (e.g. after CLUSTER SETSLOT / failover), sunsubscribe on
    the old node's pubsub and ssubscribe on the new owner's pubsub,
    preserving any registered handler.

    All state mutation happens under ``_shard_state_lock``. After the
    per-channel pass, per-node pubsubs that no longer hold any subscription
    are closed (best-effort) and dropped from ``node_pubsub_mapping``.

    :raises SlotNotCoveredError: if at least one channel's slot was not
        covered by the cluster. Every coverable channel has still been
        processed in that case.
    :raises ConnectionError | TimeoutError | OSError: the first migration
        error, re-raised only if at least one migration failed and none
        succeeded in this pass. When no migration is needed at all, the
        method returns normally.
    """
    uncovered: list = []
    # Set only when a _migrate_shard_channel call completes successfully.
    made_progress = False
    first_migrate_error: Optional[BaseException] = None
    async with self._shard_state_lock:
        # Iterate over a copy: _migrate_shard_channel updates shard_channels.
        for channel, handler in list(self.shard_channels.items()):
            try:
                new_node = self.cluster.get_node_from_key(channel)
            except SlotNotCoveredError:
                # Slot is transiently uncovered (mid-migration / partial
                # topology refresh). Defer this channel so coverable
                # siblings still reconcile this pass; we surface the
                # error below so the caller (and logs) know not every
                # channel was reconciled. Retry happens on the next
                # slots-cache change notification.
                uncovered.append(channel)
                continue
            old_name = self._shard_channel_to_node.get(channel)
            if old_name == new_node.name:
                # Already tracked on the current owner: nothing to do.
                continue
            try:
                await self._migrate_shard_channel(
                    channel, handler, old_name, new_node
                )
                made_progress = True
            except (ConnectionError, TimeoutError, OSError) as e:
                # Transient connectivity error during migration.
                # _migrate_shard_channel catches these same exception types
                # around the old owner's sunsubscribe, so in practice this
                # comes from subscribing on the new owner. Do not abort
                # reconciliation for sibling channels: _shard_channel_to_node
                # was not advanced for this channel, so the next slots-cache
                # change notification will retry it.
                logger.warning(
                    "shard channel %r migration deferred: %s: %s",
                    channel,
                    type(e).__name__,
                    e,
                )
                if first_migrate_error is None:
                    first_migrate_error = e
                continue
        # Garbage-collect per-node pubsubs that no longer hold any
        # subscription so their connections are released.
        for name, pubsub in list(self.node_pubsub_mapping.items()):
            if not pubsub.subscribed:
                try:
                    await pubsub.aclose()
                except Exception:
                    # Best-effort close; the entry is dropped regardless.
                    pass
                self.node_pubsub_mapping.pop(name, None)
    if uncovered:
        # Surface the uncovered channels so the caller (and observer
        # notification path) knows reconciliation was incomplete. All
        # coverable siblings have already been migrated above.
        raise SlotNotCoveredError(
            f"{len(uncovered)} shard channel(s) left unreconciled; "
            f"slot(s) not covered by the cluster: {uncovered!r}"
        )
    if first_migrate_error is not None and not made_progress:
        # Every migration attempted in this pass failed transiently and
        # nothing else made progress. Re-raise the first caught error
        # (typically the root cause; later failures are often downstream
        # symptoms of the same unreachable node) so the task's done-
        # callback surfaces a single representative failure through the
        # same logger channel used for SlotNotCoveredError. Per-channel
        # WARNINGs above preserve the full forensic detail.
        raise first_migrate_error

3765 

async def _migrate_shard_channel(
    self,
    channel: Any,
    handler: Optional[Callable],
    old_name: Optional[str],
    new_node: "ClusterNode",
) -> None:
    """
    Move one shard-channel subscription onto ``new_node``'s per-node pubsub.

    Steps:

    1. Best-effort SUNSUBSCRIBE on the pubsub registered under ``old_name``,
       if there is one. Connectivity errors there are swallowed. If
       ``old_name`` is also no longer known to the cluster, that pubsub is
       closed and dropped.
    2. SSUBSCRIBE on ``new_node``'s pubsub. ``channel`` is wrapped in a
       ``Subscription`` when a truthy ``handler`` is given.
    3. Merge the new pubsub's ``shard_channels`` into the cluster-level map,
       point the reverse index at ``new_node.name``, and clear any
       pending-unsubscribe marker for the channel.

    This method does not acquire ``_shard_state_lock``. Its callers
    (``ssubscribe`` and ``reinitialize_shard_subscriptions``) already hold
    it, and asyncio.Lock is non-reentrant.

    :param channel: Shard channel name
    :param handler: Handler to register for the channel, or None
    :param old_name: Name of the node whose pubsub currently holds the
        subscription, or None if it is not tracked
    :param new_node: Node that now owns the channel's slot
    :raises: Any exception from the SSUBSCRIBE on the new owner propagates
        unchanged. In that case the reverse index is not updated.
    """
    # Detach from the old per-node pubsub, best-effort: the old node may
    # already be unreachable during migration / failover.
    if old_name and old_name in self.node_pubsub_mapping:
        old_pubsub = self.node_pubsub_mapping[old_name]
        try:
            await old_pubsub.sunsubscribe(channel)
        except (ConnectionError, TimeoutError, OSError):
            # redis-py's Connection has already called ``disconnect()``
            # before raising (see Connection.read_response /
            # send_packed_command with ``disconnect_on_error=True``),
            # so ``old_pubsub``'s dedicated socket is gone. Two cases:
            #
            # 1. The old node is no longer in the cluster topology
            #    (e.g. removed by failover / topology refresh). No
            #    reconnect target exists, so ``old_pubsub.subscribed``
            #    would stay True forever and the end-of-pass GC in
            #    reinitialize_shard_subscriptions would skip it. Drop it
            #    eagerly so the round-robin generator does not keep
            #    yielding a dead pubsub that produces periodic errors from
            #    ``get_sharded_message``.
            # 2. The old node is still known (transiently slow /
            #    unreachable). ``PubSub._execute`` auto-reconnects and
            #    ``on_connect`` re-subscribes to the remaining channels,
            #    so other subscriptions on the same pubsub recover
            #    naturally. Leave it alone.
            if self.cluster.get_node(node_name=old_name) is None:
                try:
                    await old_pubsub.aclose()
                except Exception:
                    pass
                self.node_pubsub_mapping.pop(old_name, None)
    # Attach to the new per-node pubsub, preserving the handler by wrapping
    # the channel in a Subscription when one is supplied.
    new_pubsub = self._get_node_pubsub(new_node)
    if handler:
        await new_pubsub.ssubscribe(Subscription(channel, handler))
    else:
        await new_pubsub.ssubscribe(channel)
    self.shard_channels.update(new_pubsub.shard_channels)
    # Normalize once. _normalize_keys maps the single input key to a single
    # normalized key, which both indexes below use.
    normalized_key = next(iter(self._normalize_keys({channel: None})))
    self._shard_channel_to_node[normalized_key] = new_node.name
    self.pending_unsubscribe_shard_channels.discard(normalized_key)

3816 

async def on_slots_changed(self) -> None:
    """
    Observer hook called after a slots-cache refresh.

    When shard channels are tracked, schedule reinitialize_shard_subscriptions
    as a background task rather than awaiting it. This keeps the caller's
    path (typically MovedError handling) from blocking on the network I/O
    involved. The task is kept in ``_reconcile_tasks`` until it finishes,
    and any exception it raises is consumed and logged by
    ``_log_reconcile_task_exception``. That avoids a "Task exception was
    never retrieved" warning.
    """
    if self.shard_channels:
        reconcile = asyncio.create_task(self.reinitialize_shard_subscriptions())
        self._reconcile_tasks.add(reconcile)
        # Order matters only for readability: drop the reference, then log.
        for callback in (
            self._reconcile_tasks.discard,
            self._log_reconcile_task_exception,
        ):
            reconcile.add_done_callback(callback)

3835 

@staticmethod
def _log_reconcile_task_exception(task: "asyncio.Task") -> None:
    """
    Done-callback for reconciliation tasks: retrieve and log any failure.

    Cancelled tasks are ignored. Otherwise ``task.exception()`` is read,
    which marks it retrieved, and a non-None exception is logged at ERROR
    level together with its traceback.
    """
    if not task.cancelled():
        failure = task.exception()
        if failure is not None:
            logger.error(
                "shard subscription reconciliation failed: %r",
                failure,
                exc_info=failure,
            )

3845 

def get_redis_connection(self) -> Optional["AbstractConnection"]:
    """
    Get the Redis connection of the pubsub connected node.

    This is the pubsub's own dedicated connection (``self.connection``,
    obtained from this pubsub's connection pool when commands run), not a
    connection checked out of the ClusterNode's pool. Handing that out via
    node.acquire_connection() without releasing it would leak pool
    connections.

    :return: The dedicated connection, or None if none has been established
    """
    dedicated_connection = self.connection
    return dedicated_connection

3860 

async def aclose(self) -> None:
    """
    Disconnect the pubsub connection and every per-node shard pubsub.

    Teardown order:

    1. Cancel and await in-flight reconciliation tasks, outside the lock.
    2. Under ``_shard_state_lock``:
       - close and drop every per-node pubsub;
       - rebuild the round-robin generator;
       - let the parent close ``self.connection``;
       - clear the reverse index.

    A failing per-node ``aclose()`` does not abort teardown. The remaining
    pubsubs are still closed and the parent connection is still released.
    The first such failure is re-raised after teardown has completed.

    :raises Exception: the first exception raised by a per-node ``aclose()``,
        after all other cleanup has run
    """
    # Cancel and gather in-flight reconciliation tasks BEFORE acquiring
    # _shard_state_lock. The tasks themselves take that lock inside
    # reinitialize_shard_subscriptions; since asyncio.Lock is non-
    # reentrant, gathering while holding it would deadlock. Awaiting
    # each task with return_exceptions=True also avoids unhandled-
    # exception warnings if the task was created but not yet scheduled.
    if self._reconcile_tasks:
        tasks = list(self._reconcile_tasks)
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    first_close_error: Optional[Exception] = None
    # Hold _shard_state_lock across the rest of the teardown so it follows
    # the same mutual-exclusion discipline as ssubscribe / sunsubscribe /
    # get_sharded_message / reinitialize_shard_subscriptions. Those mutate
    # shard_channels, _shard_channel_to_node and node_pubsub_mapping under
    # this lock. Without it, super().aclose() could rebind shard_channels
    # while a concurrent mutation resumes during one of the awaits below,
    # silently dropping subscription intent.
    async with self._shard_state_lock:
        self._reconcile_tasks.clear()
        # Drain the mapping instead of iterating a live view.
        # _get_node_pubsub, which execute_command reaches without this lock,
        # can insert entries while we await aclose(). Iterating a live
        # dict view would then raise, and clearing a pre-taken snapshot
        # would leak the newly added pubsub. Draining also ensures no
        # dead pubsub stays visible to the round-robin generator.
        while self.node_pubsub_mapping:
            node_name = next(iter(self.node_pubsub_mapping))
            pubsub = self.node_pubsub_mapping.pop(node_name)
            try:
                await pubsub.aclose()
            except Exception as exc:
                # Keep tearing down the rest; surface the first failure later.
                if first_close_error is None:
                    first_close_error = exc
        # _pubsubs_generator captures node_pubsub_mapping.values() into
        # a local list inside ``yield from``. Emptying the mapping does not
        # reach references already held by that captured snapshot, so a
        # generator suspended mid-yield-from would still surface the
        # now-closed per-node pubsubs after re-subscription. Recreate it
        # to drop the captured list. type(self) bypasses the instance-level
        # self-shadow established at __init__
        # (self._pubsubs_generator = self._pubsubs_generator()).
        self._pubsubs_generator = type(self)._pubsubs_generator(  # type: ignore[method-assign]
            self
        )
        try:
            # Let the parent disconnect self.connection (includes disconnect,
            # release to pool, and clearing self.connection).
            await super().aclose()
        finally:
            # Clear the reverse index so a reused instance doesn't route
            # against stale mappings, even if the parent close failed.
            # super().aclose() is responsible for clearing shard_channels.
            self._shard_channel_to_node.clear()
    if first_close_error is not None:
        raise first_close_error

3913 

def _raise_on_invalid_node(
    self,
    redis_cluster: "RedisCluster",
    node: Optional["ClusterNode"],
    host: Optional[str],
    port: Optional[int],
) -> None:
    """
    Raise a RedisClusterException if the node is None or doesn't exist in
    the cluster.

    :param redis_cluster: Cluster client used to look the node up by name
    :param node: Candidate node, possibly None
    :param host: Host used only in the error message
    :param port: Port used only in the error message
    :raises RedisClusterException: when ``node`` is None or
        ``redis_cluster.get_node(node_name=node.name)`` returns None
    """
    known = node is not None and redis_cluster.get_node(node_name=node.name) is not None
    if not known:
        raise RedisClusterException(
            f"Node {host}:{port} doesn't exist in the cluster"
        )

3928 ) 

3929 

async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
    """
    Execute a command on the appropriate cluster node.

    Adapted from redis-py's PubSub.execute_command for cluster use.

    Routing:

    - SSUBSCRIBE / SUNSUBSCRIBE / SPUBLISH with at least one channel go to
      the per-node pubsub of the node returned by ``get_node_from_key`` for
      the first channel.
    - Any other command, or a shard command whose node lookup returned a
      falsy value, goes through the parent implementation. If neither a
      connection nor a connection pool exists yet, a node is chosen first:
      the slot owner of the first channel (honouring the cluster's replica
      and load-balancing settings), or a random node when there are no
      channel arguments.
    """
    # NOTE: don't parse the response in this function -- it could pull a
    # legitimate message off the stack if the connection is already
    # subscribed to one or more channels
    command = args[0].upper() if args else ""
    is_shard_command = command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH")
    if is_shard_command and len(args) > 1:
        owner = self.cluster.get_node_from_key(args[1])
        if owner:
            return await self._get_node_pubsub(owner).execute_command(
                *args, **kwargs
            )

    # Lazily bind this pubsub to a node the first time it is needed.
    if self.connection is None and self.connection_pool is None:
        if len(args) > 1:
            # Hash the first channel and pick one of the nodes serving it.
            node = self.cluster.nodes_manager.get_node_from_slot(
                self.cluster.keyslot(args[1]),
                self.cluster.read_from_replicas,
                self.cluster.load_balancing_strategy,
            )
        else:
            node = self.cluster.get_random_node()
        self.node = node
        self.connection_pool = _ClusterNodePoolAdapter(node)

    # A connection_pool now exists; defer to the parent implementation.
    return await super().execute_command(*args, **kwargs)