Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/cluster.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1164 statements  

1import asyncio 

2import collections 

3import random 

4import socket 

5import threading 

6import time 

7import warnings 

8from abc import ABC, abstractmethod 

9from copy import copy 

10from itertools import chain 

11from typing import ( 

12 Any, 

13 Callable, 

14 Coroutine, 

15 Deque, 

16 Dict, 

17 Generator, 

18 List, 

19 Mapping, 

20 Optional, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypeVar, 

25 Union, 

26) 

27 

28from redis._parsers import AsyncCommandsParser, Encoder 

29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy 

30from redis._parsers.helpers import ( 

31 _RedisCallbacks, 

32 _RedisCallbacksRESP2, 

33 _RedisCallbacksRESP3, 

34) 

35from redis.asyncio.client import ResponseCallbackT 

36from redis.asyncio.connection import Connection, SSLConnection, parse_url 

37from redis.asyncio.lock import Lock 

38from redis.asyncio.observability.recorder import ( 

39 record_error_count, 

40 record_operation_duration, 

41) 

42from redis.asyncio.retry import Retry 

43from redis.auth.token import TokenInterface 

44from redis.backoff import ExponentialWithJitterBackoff, NoBackoff 

45from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis 

46from redis.cluster import ( 

47 PIPELINE_BLOCKED_COMMANDS, 

48 PRIMARY, 

49 REPLICA, 

50 SLOT_ID, 

51 AbstractRedisCluster, 

52 LoadBalancer, 

53 LoadBalancingStrategy, 

54 block_pipeline_command, 

55 get_node_name, 

56 parse_cluster_slots, 

57) 

58from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands 

59from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver 

60from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot 

61from redis.credentials import CredentialProvider 

62from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher 

63from redis.exceptions import ( 

64 AskError, 

65 BusyLoadingError, 

66 ClusterDownError, 

67 ClusterError, 

68 ConnectionError, 

69 CrossSlotTransactionError, 

70 DataError, 

71 ExecAbortError, 

72 InvalidPipelineStack, 

73 MaxConnectionsError, 

74 MovedError, 

75 RedisClusterException, 

76 RedisError, 

77 ResponseError, 

78 SlotNotCoveredError, 

79 TimeoutError, 

80 TryAgainError, 

81 WatchError, 

82) 

83from redis.typing import AnyKeyT, EncodableT, KeyT 

84from redis.utils import ( 

85 SSL_AVAILABLE, 

86 deprecated_args, 

87 deprecated_function, 

88 get_lib_version, 

89 safe_str, 

90 str_if_bytes, 

91 truncate_text, 

92) 

93 

# The ssl module may be missing when Python was built without SSL support.
# Fall back to ``None`` placeholders so the type annotations below still
# evaluate at import time; the SSL code paths are only reached when the
# caller explicitly enables TLS.
if SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

100 

# Constrained TypeVar for the ``target_nodes`` kwarg accepted by many cluster
# commands: a node-flag string, a single node, a list of nodes, or a mapping
# whose values are nodes.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)

104 

105 

106class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

107 """ 

108 Create a new RedisCluster client. 

109 

110 Pass one of parameters: 

111 

112 - `host` & `port` 

113 - `startup_nodes` 

114 

115 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. 

116 | Use ``await`` :meth:`close` to disconnect connections & close client. 

117 

118 Many commands support the target_nodes kwarg. It can be one of the 

119 :attr:`NODE_FLAGS`: 

120 

121 - :attr:`PRIMARIES` 

122 - :attr:`REPLICAS` 

123 - :attr:`ALL_NODES` 

124 - :attr:`RANDOM` 

125 - :attr:`DEFAULT_NODE` 

126 

127 Note: This client is not thread/process/fork safe. 

128 

129 :param host: 

130 | Can be used to point to a startup node 

131 :param port: 

132 | Port used if **host** is provided 

133 :param startup_nodes: 

134 | :class:`~.ClusterNode` to be used as a startup node 

135 :param require_full_coverage: 

136 | When set to ``False``: the client will not require a full coverage of 

137 the slots. However, if not all slots are covered, and at least one node 

138 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw 

139 a :class:`~.ClusterDownError` for some key-based commands. 

140 | When set to ``True``: all slots must be covered to construct the cluster 

141 client. If not all slots are covered, :class:`~.RedisClusterException` will be 

142 thrown. 

143 | See: 

144 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters 

145 :param read_from_replicas: 

146 | @deprecated - please use load_balancing_strategy instead 

147 | Enable read from replicas in READONLY mode. 

148 When set to true, read commands will be assigned between the primary and 

149 its replications in a Round-Robin manner. 

150 The data read from replicas is eventually consistent with the data in primary nodes. 

151 :param load_balancing_strategy: 

152 | Enable read from replicas in READONLY mode and defines the load balancing 

153 strategy that will be used for cluster node selection. 

154 The data read from replicas is eventually consistent with the data in primary nodes. 

155 :param dynamic_startup_nodes: 

156 | Set the RedisCluster's startup nodes to all the discovered nodes. 

157 If true (default value), the cluster's discovered nodes will be used to 

158 determine the cluster nodes-slots mapping in the next topology refresh. 

159 It will remove the initial passed startup nodes if their endpoints aren't 

160 listed in the CLUSTER SLOTS output. 

161 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists 

162 specific IP addresses, it is best to set it to false. 

163 :param reinitialize_steps: 

164 | Specifies the number of MOVED errors that need to occur before reinitializing 

165 the whole cluster topology. If a MOVED error occurs and the cluster does not 

166 need to be reinitialized on this current error handling, only the MOVED slot 

167 will be patched with the redirected node. 

168 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. 

169 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 

170 0. 

171 :param cluster_error_retry_attempts: 

172 | @deprecated - Please configure the 'retry' object instead 

173 In case 'retry' object is set - this argument is ignored! 

174 

175 Number of times to retry before raising an error when :class:`~.TimeoutError`, 

176 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` 

177 or :class:`~.ClusterDownError` are encountered 

178 :param retry: 

179 | A retry object that defines the retry strategy and the number of 

180 retries for the cluster client. 

181 In current implementation for the cluster client (starting from redis-py version 6.0.0) 

182 the retry object is not yet fully utilized, instead it is used just to determine 

183 the number of retries for the cluster client. 

184 In the future releases the retry object will be used to handle the cluster client retries! 

185 :param max_connections: 

186 | Maximum number of connections per node. If there are no free connections & the 

187 maximum number of connections are already created, a 

188 :class:`~.MaxConnectionsError` is raised. 

189 :param address_remap: 

190 | An optional callable which, when provided with an internal network 

191 address of a node, e.g. a `(host, port)` tuple, will return the address 

192 where the node is reachable. This can be used to map the addresses at 

193 which the nodes _think_ they are, to addresses at which a client may 

194 reach them, such as when they sit behind a proxy. 

195 

196 | Rest of the arguments will be passed to the 

197 :class:`~redis.asyncio.connection.Connection` instances when created 

198 

199 :raises RedisClusterException: 

200 if any arguments are invalid or unknown. Eg: 

201 

202 - `db` != 0 or None 

203 - `path` argument for unix socket connection 

204 - none of the `host`/`port` & `startup_nodes` were provided 

205 

206 """ 

207 

@classmethod
def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
    """
    Build a :class:`RedisCluster` client configured from ``url``.

    Supported schemes::

        redis://[[username]:[password]]@localhost:6379/0
        rediss://[[username]:[password]]@localhost:6379/0

    - ``redis://`` creates a plain TCP connection
      (<https://www.iana.org/assignments/uri-schemes/prov/redis>)
    - ``rediss://`` creates an SSL-wrapped TCP connection
      (<https://www.iana.org/assignments/uri-schemes/prov/rediss>)

    Username, password, hostname, path and querystring values are passed
    through ``urllib.parse.unquote``, so percent-encoded characters are
    decoded. Querystring options are cast to their appropriate Python types
    (booleans accept "True"/"False" or "Yes"/"No"; values that cannot be
    cast raise ``ValueError``) and are forwarded, together with ``kwargs``,
    to :class:`~redis.asyncio.connection.Connection`. On conflicts, the
    querystring wins.
    """
    url_options = parse_url(url)
    kwargs.update(url_options)
    # ``parse_url`` signals TLS via the connection class; the cluster client
    # picks its own connection classes, so translate it to the ``ssl`` flag.
    connection_class = kwargs.pop("connection_class", None)
    if connection_class is SSLConnection:
        kwargs["ssl"] = True
    return cls(**kwargs)

240 

# Slots for the core attributes assigned in __init__.
# NOTE(review): several attributes set in __init__ (e.g. ``_event_dispatcher``,
# ``startup_nodes``, ``_usage_counter``) are not listed here — presumably a
# ``__dict__`` is still available through the non-slotted base classes;
# verify before relying on the memory savings.
__slots__ = (
    "_initialize",
    "_lock",
    "retry",
    "command_flags",
    "commands_parser",
    "connection_kwargs",
    "encoder",
    "node_flags",
    "nodes_manager",
    "read_from_replicas",
    "reinitialize_counter",
    "reinitialize_steps",
    "response_callbacks",
    "result_callbacks",
)

257 

@deprecated_args(
    args_to_warn=["read_from_replicas"],
    reason="Please configure the 'load_balancing_strategy' instead",
    version="5.3.0",
)
@deprecated_args(
    args_to_warn=[
        "cluster_error_retry_attempts",
    ],
    reason="Please configure the 'retry' object instead",
    version="6.0.0",
)
def __init__(
    self,
    host: Optional[str] = None,
    port: Union[str, int] = 6379,
    # Cluster related kwargs
    startup_nodes: Optional[List["ClusterNode"]] = None,
    require_full_coverage: bool = True,
    read_from_replicas: bool = False,
    load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
    dynamic_startup_nodes: bool = True,
    reinitialize_steps: int = 5,
    cluster_error_retry_attempts: int = 3,
    max_connections: int = 2**31,
    retry: Optional["Retry"] = None,
    retry_on_error: Optional[List[Type[Exception]]] = None,
    # Client related kwargs
    db: Union[str, int] = 0,
    path: Optional[str] = None,
    credential_provider: Optional[CredentialProvider] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    client_name: Optional[str] = None,
    lib_name: Optional[str] = "redis-py",
    lib_version: Optional[str] = get_lib_version(),
    # Encoding related kwargs
    encoding: str = "utf-8",
    encoding_errors: str = "strict",
    decode_responses: bool = False,
    # Connection related kwargs
    health_check_interval: float = 0,
    socket_connect_timeout: Optional[float] = None,
    socket_keepalive: bool = False,
    socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
    socket_timeout: Optional[float] = None,
    # SSL related kwargs
    ssl: bool = False,
    ssl_ca_certs: Optional[str] = None,
    ssl_ca_data: Optional[str] = None,
    ssl_cert_reqs: Union[str, VerifyMode] = "required",
    ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_certfile: Optional[str] = None,
    ssl_check_hostname: bool = True,
    ssl_keyfile: Optional[str] = None,
    ssl_min_version: Optional[TLSVersion] = None,
    ssl_ciphers: Optional[str] = None,
    protocol: Optional[int] = 2,
    address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
    event_dispatcher: Optional[EventDispatcher] = None,
    policy_resolver: Optional[AsyncPolicyResolver] = None,
) -> None:
    """Initialize the cluster client; see the class docstring for parameters."""
    # FIX: ``policy_resolver`` previously defaulted to a shared
    # ``AsyncStaticPolicyResolver()`` instance created once at import time
    # (a mutable default argument shared by every client). Default to None
    # and create a fresh resolver per client instead — backward compatible
    # for all callers.
    if policy_resolver is None:
        policy_resolver = AsyncStaticPolicyResolver()

    if db:
        # Redis Cluster only supports database 0.
        raise RedisClusterException(
            "Argument 'db' must be 0 or None in cluster mode"
        )

    if path:
        raise RedisClusterException(
            "Unix domain socket is not supported in cluster mode"
        )

    if (not host or not port) and not startup_nodes:
        raise RedisClusterException(
            "RedisCluster requires at least one node to discover the cluster.\n"
            "Please provide one of the following or use RedisCluster.from_url:\n"
            ' - host and port: RedisCluster(host="localhost", port=6379)\n'
            " - startup_nodes: RedisCluster(startup_nodes=["
            'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
        )

    # Kwargs forwarded to every per-node Connection created by the client.
    kwargs: Dict[str, Any] = {
        "max_connections": max_connections,
        "connection_class": Connection,
        # Client related kwargs
        "credential_provider": credential_provider,
        "username": username,
        "password": password,
        "client_name": client_name,
        "lib_name": lib_name,
        "lib_version": lib_version,
        # Encoding related kwargs
        "encoding": encoding,
        "encoding_errors": encoding_errors,
        "decode_responses": decode_responses,
        # Connection related kwargs
        "health_check_interval": health_check_interval,
        "socket_connect_timeout": socket_connect_timeout,
        "socket_keepalive": socket_keepalive,
        "socket_keepalive_options": socket_keepalive_options,
        "socket_timeout": socket_timeout,
        "protocol": protocol,
    }

    if ssl:
        # SSL related kwargs
        kwargs.update(
            {
                "connection_class": SSLConnection,
                "ssl_ca_certs": ssl_ca_certs,
                "ssl_ca_data": ssl_ca_data,
                "ssl_cert_reqs": ssl_cert_reqs,
                "ssl_include_verify_flags": ssl_include_verify_flags,
                "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                "ssl_certfile": ssl_certfile,
                "ssl_check_hostname": ssl_check_hostname,
                "ssl_keyfile": ssl_keyfile,
                "ssl_min_version": ssl_min_version,
                "ssl_ciphers": ssl_ciphers,
            }
        )

    if read_from_replicas or load_balancing_strategy:
        # Call our on_connect function to configure READONLY mode
        kwargs["redis_connect_func"] = self.on_connect

    if retry:
        # An explicit retry object takes precedence over
        # cluster_error_retry_attempts (deprecated).
        self.retry = retry
    else:
        self.retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10),
            retries=cluster_error_retry_attempts,
        )
    if retry_on_error:
        self.retry.update_supported_errors(retry_on_error)

    kwargs["response_callbacks"] = _RedisCallbacks.copy()
    if kwargs.get("protocol") in ["3", 3]:
        kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
    else:
        kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
    self.connection_kwargs = kwargs

    if startup_nodes:
        # Re-create the passed nodes so they share this client's
        # connection kwargs.
        passed_nodes = []
        for node in startup_nodes:
            passed_nodes.append(
                ClusterNode(node.host, node.port, **self.connection_kwargs)
            )
        startup_nodes = passed_nodes
    else:
        startup_nodes = []
        if host and port:
            startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

    if event_dispatcher is None:
        self._event_dispatcher = EventDispatcher()
    else:
        self._event_dispatcher = event_dispatcher

    self.startup_nodes = startup_nodes
    self.nodes_manager = NodesManager(
        startup_nodes,
        require_full_coverage,
        kwargs,
        dynamic_startup_nodes=dynamic_startup_nodes,
        address_remap=address_remap,
        event_dispatcher=self._event_dispatcher,
    )
    self.encoder = Encoder(encoding, encoding_errors, decode_responses)
    self.read_from_replicas = read_from_replicas
    self.load_balancing_strategy = load_balancing_strategy
    self.reinitialize_steps = reinitialize_steps
    self.reinitialize_counter = 0

    # For backward compatibility, mapping from existing policies to new one
    self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
        self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
        self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
        self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
        self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
        self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
        SLOT_ID: RequestPolicy.DEFAULT_KEYED,
    }

    # Maps each request/response policy to the callable that produces the
    # target node list (or transforms the response).
    self._policies_callback_mapping: dict[
        Union[RequestPolicy, ResponsePolicy], Callable
    ] = {
        RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
            self.get_random_primary_or_all_nodes(command_name)
        ],
        RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
        RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
        RequestPolicy.ALL_SHARDS: self.get_primaries,
        RequestPolicy.ALL_NODES: self.get_nodes,
        RequestPolicy.ALL_REPLICAS: self.get_replicas,
        RequestPolicy.SPECIAL: self.get_special_nodes,
        ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
        ResponsePolicy.DEFAULT_KEYED: lambda res: res,
    }

    self._policy_resolver = policy_resolver
    self.commands_parser = AsyncCommandsParser()
    self._aggregate_nodes = None
    self.node_flags = self.__class__.NODE_FLAGS.copy()
    self.command_flags = self.__class__.COMMAND_FLAGS.copy()
    self.response_callbacks = kwargs["response_callbacks"]
    self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
    self.result_callbacks["CLUSTER SLOTS"] = (
        lambda cmd, res, **kwargs: parse_cluster_slots(
            list(res.values())[0], **kwargs
        )
    )

    self._initialize = True
    self._lock: Optional[asyncio.Lock] = None

    # When used as an async context manager, we need to increment and decrement
    # a usage counter so that we can close the connection pool when no one is
    # using the client.
    self._usage_counter = 0
    self._usage_lock = asyncio.Lock()

481 

async def initialize(self) -> "RedisCluster":
    """Get all nodes from startup nodes & creates connections if not initialized."""
    if self._initialize:
        # The lock is created lazily so the client can be constructed
        # outside of a running event loop.
        if not self._lock:
            self._lock = asyncio.Lock()
        async with self._lock:
            # Double-checked under the lock: another task may have finished
            # initializing while we waited.
            if self._initialize:
                try:
                    await self.nodes_manager.initialize()
                    await self.commands_parser.initialize(
                        self.nodes_manager.default_node
                    )
                    self._initialize = False
                except BaseException:
                    # Tear down any partially-created connections (both the
                    # discovered nodes and the startup nodes) before
                    # propagating the failure.
                    await self.nodes_manager.aclose()
                    await self.nodes_manager.aclose("startup_nodes")
                    raise
    return self

500 

async def aclose(self) -> None:
    """Close all connections & client if initialized."""
    if not self._initialize:
        if not self._lock:
            self._lock = asyncio.Lock()
        async with self._lock:
            # Double-checked under the lock so only the first closer runs.
            if not self._initialize:
                # Flip the flag first: concurrent callers immediately see
                # the client as uninitialized/closed.
                self._initialize = True
                await self.nodes_manager.aclose()
                await self.nodes_manager.aclose("startup_nodes")

511 

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
    """Deprecated alias of :meth:`aclose`, kept for backwards compatibility."""
    await self.aclose()

516 

async def __aenter__(self) -> "RedisCluster":
    """
    Async context manager entry. Increments a usage counter so that the
    connection pool is only closed (via aclose()) when no context is using
    the client.
    """
    await self._increment_usage()
    try:
        # Initialize the client (i.e. establish connection, etc.)
        return await self.initialize()
    except Exception:
        # If initialization fails, decrement the counter to keep it in sync
        await self._decrement_usage()
        raise

531 

async def _increment_usage(self) -> int:
    """
    Atomically bump the context-manager usage counter.

    :return: the counter value after the increment.
    """
    async with self._usage_lock:
        new_count = self._usage_counter + 1
        self._usage_counter = new_count
        return new_count

540 

async def _decrement_usage(self) -> int:
    """
    Atomically lower the context-manager usage counter.

    :return: the counter value after the decrement.
    """
    async with self._usage_lock:
        new_count = self._usage_counter - 1
        self._usage_counter = new_count
        return new_count

549 

async def __aexit__(self, exc_type, exc_value, traceback):
    """
    Async context manager exit. Decrements a usage counter. If this is the
    last exit (counter becomes zero), the client closes its connection pool.
    """
    # shield() keeps the counter (and the close) consistent even if this
    # coroutine is cancelled while awaiting.
    current_usage = await asyncio.shield(self._decrement_usage())
    if current_usage == 0:
        # This was the last active context, so disconnect the pool.
        await asyncio.shield(self.aclose())

559 

def __await__(self) -> Generator[Any, None, "RedisCluster"]:
    # Allows ``client = await RedisCluster(...)`` as a shorthand for
    # awaiting :meth:`initialize`.
    return self.initialize().__await__()

# Warning text emitted when the client is garbage-collected while open.
_DEL_MESSAGE = "Unclosed RedisCluster client"

564 

def __del__(
    self,
    _warn: Any = warnings.warn,
    _grl: Any = asyncio.get_running_loop,
) -> None:
    # Finalizer: warn if the client is collected while still initialized.
    # ``warnings.warn`` / ``asyncio.get_running_loop`` are bound as default
    # arguments so they stay reachable during interpreter shutdown.
    if hasattr(self, "_initialize") and not self._initialize:
        _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
        try:
            context = {"client": self, "message": self._DEL_MESSAGE}
            _grl().call_exception_handler(context)
        except RuntimeError:
            # No running event loop — nothing more to report.
            pass

577 

async def on_connect(self, connection: Connection) -> None:
    """
    Connection hook used in READONLY mode: after the regular handshake,
    switch the connection into READONLY state.

    Each cluster node may change its server type due to a failover, so the
    READONLY command is sent regardless of the node's current role — on a
    primary it has no effect on write commands.

    :raises ConnectionError: if the server does not acknowledge READONLY.
    """
    await connection.on_connect()

    await connection.send_command("READONLY")
    reply = await connection.read_response()
    if str_if_bytes(reply) != "OK":
        raise ConnectionError("READONLY command failed")

589 

def get_nodes(self) -> List["ClusterNode"]:
    """Return every known node of the cluster as a list."""
    return [*self.nodes_manager.nodes_cache.values()]

593 

def get_primaries(self) -> List["ClusterNode"]:
    """Return all nodes currently acting as primaries."""
    return self.nodes_manager.get_nodes_by_server_type(PRIMARY)

597 

def get_replicas(self) -> List["ClusterNode"]:
    """Return all nodes currently acting as replicas."""
    return self.nodes_manager.get_nodes_by_server_type(REPLICA)

601 

def get_random_node(self) -> "ClusterNode":
    """Pick one node (primary or replica) uniformly at random."""
    all_nodes = list(self.nodes_manager.nodes_cache.values())
    return random.choice(all_nodes)

605 

def get_default_node(self) -> "ClusterNode":
    """Return the client's current default node."""
    return self.nodes_manager.default_node

609 

def set_default_node(self, node: "ClusterNode") -> None:
    """
    Make ``node`` the client's default node.

    :raises DataError: if ``node`` is falsy or unknown to the cluster.
    """
    if node and self.get_node(node_name=node.name):
        self.nodes_manager.default_node = node
        return
    raise DataError("The requested node does not exist in the cluster.")

620 

def get_node(
    self,
    host: Optional[str] = None,
    port: Optional[int] = None,
    node_name: Optional[str] = None,
) -> Optional["ClusterNode"]:
    """Look up a cluster node by ``(host, port)`` pair or by ``node_name``."""
    manager = self.nodes_manager
    return manager.get_node(host, port, node_name)

629 

def get_node_from_key(
    self, key: str, replica: bool = False
) -> Optional["ClusterNode"]:
    """
    Return the cluster node that serves ``key``'s hash slot.

    :param key: the key whose slot determines the node.
    :param replica:
        | When True, return a replica holding the slot instead of the
          primary; ``None`` is returned if no replica holds this key.

    :raises SlotNotCoveredError: if the key's slot is not covered by any node.
    """
    slot = self.keyslot(key)
    nodes_for_slot = self.nodes_manager.slots_cache.get(slot)
    if not nodes_for_slot:
        raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')

    if not replica:
        # Index 0 always holds the slot's primary.
        return nodes_for_slot[0]
    if len(nodes_for_slot) < 2:
        # Slot has a primary only — no replica to return.
        return None
    return nodes_for_slot[1]

657 

def get_random_primary_or_all_nodes(self, command_name):
    """
    Pick a single node for a keyless command.

    When READONLY mode is enabled (``read_from_replicas``) and
    ``command_name`` is a read command, a random node — primary or
    replica — is returned; otherwise a random primary is returned.
    """
    if self.read_from_replicas and command_name in READ_COMMANDS:
        return self.get_random_node()

    return self.get_random_primary_node()

666 

def get_random_primary_node(self) -> "ClusterNode":
    """Pick one primary node uniformly at random."""
    primaries = self.get_primaries()
    return random.choice(primaries)

672 

async def get_nodes_from_slot(self, command: str, *args):
    """
    Return a one-element list with the node that holds the slot of the
    keys referenced by ``command``/``args``.
    """
    slot = await self._determine_slot(command, *args)
    is_read_command = command in READ_COMMANDS
    target = self.nodes_manager.get_node_from_slot(
        slot,
        self.read_from_replicas and is_read_command,
        self.load_balancing_strategy if is_read_command else None,
    )
    return [target]

685 

def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
    """
    Return the nodes recorded by a previous FT.AGGREGATE call, for commands
    with a "special" policy (e.g. FT.CURSOR follow-ups).

    :raises RedisClusterException: if no FT.AGGREGATE has run yet.
    """
    if self._aggregate_nodes:
        return self._aggregate_nodes
    raise RedisClusterException(
        "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
    )

696 

def keyslot(self, key: EncodableT) -> int:
    """
    Return the hash slot that ``key`` maps to.

    See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
    """
    encoded_key = self.encoder.encode(key)
    return key_slot(encoded_key)

704 

def get_encoder(self) -> Encoder:
    """Return the client's :class:`Encoder` instance."""
    return self.encoder

708 

def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
    """Return the kwargs forwarded to each per-node Connection."""
    return self.connection_kwargs

712 

def set_retry(self, retry: Retry) -> None:
    """Replace the retry strategy used by the cluster client."""
    self.retry = retry

715 

def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
    """Install a custom parser for responses to ``command``."""
    self.response_callbacks[command] = callback

719 

async def _determine_nodes(
    self,
    command: str,
    *args: Any,
    request_policy: RequestPolicy,
    node_flag: Optional[str] = None,
) -> List["ClusterNode"]:
    """
    Resolve the list of target nodes for ``command``.

    An explicit or predefined node flag (legacy API) overrides the
    resolved ``request_policy`` via ``_command_flags_mapping``.
    """
    # Determine which nodes should be executed the command on.
    # Returns a list of target nodes.
    if not node_flag:
        # get the nodes group for this command if it was predefined
        node_flag = self.command_flags.get(command)

    if node_flag in self._command_flags_mapping:
        # Legacy node flags take precedence over the resolved policy.
        request_policy = self._command_flags_mapping[node_flag]

    policy_callback = self._policies_callback_mapping[request_policy]

    if request_policy == RequestPolicy.DEFAULT_KEYED:
        # Keyed policy needs the full command to compute the target slot.
        nodes = await policy_callback(command, *args)
    elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
        nodes = policy_callback(command)
    else:
        nodes = policy_callback()

    if command.lower() == "ft.aggregate":
        # Remember the targets so follow-up FT.CURSOR calls (SPECIAL
        # policy) hit the same nodes.
        self._aggregate_nodes = nodes

    return nodes

749 

async def _determine_slot(self, command: str, *args: Any) -> int:
    """
    Compute the hash slot targeted by ``command`` with ``args``.

    :raises RedisClusterException: when the command carries no key (and is
        not FCALL/FCALL_RO with zero keys) or its keys map to different
        slots.
    """
    if self.command_flags.get(command) == SLOT_ID:
        # The command contains the slot ID
        return int(args[0])

    # Get the keys in the command

    # EVAL and EVALSHA are common enough that it's wasteful to go to the
    # redis server to parse the keys. Besides, there is a bug in redis<7.0
    # where `self._get_command_keys()` fails anyway. So, we special case
    # EVAL/EVALSHA.
    # - issue: https://github.com/redis/redis/issues/9493
    # - fix: https://github.com/redis/redis/pull/9733
    if command.upper() in ("EVAL", "EVALSHA"):
        # command syntax: EVAL "script body" num_keys ...
        if len(args) < 2:
            raise RedisClusterException(
                f"Invalid args in command: {command, *args}"
            )
        # args[1] is num_keys; the keys follow immediately after it.
        keys = args[2 : 2 + int(args[1])]
        # if there are 0 keys, that means the script can be run on any node
        # so we can just return a random slot
        if not keys:
            return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
    else:
        # Ask the (server-backed) command parser which args are keys.
        keys = await self.commands_parser.get_keys(command, *args)
        if not keys:
            # FCALL can call a function with 0 keys, that means the function
            # can be run on any node so we can just return a random slot
            if command.upper() in ("FCALL", "FCALL_RO"):
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster. "
                "Missing key.\nYou can execute the command by specifying "
                f"target nodes.\nCommand: {args}"
            )

    # single key command
    if len(keys) == 1:
        return self.keyslot(keys[0])

    # multi-key command; we need to make sure all keys are mapped to
    # the same slot
    slots = {self.keyslot(key) for key in keys}
    if len(slots) != 1:
        raise RedisClusterException(
            f"{command} - all keys must map to the same key slot"
        )

    return slots.pop()

800 

def _is_node_flag(self, target_nodes: Any) -> bool:
    """Return True when ``target_nodes`` is one of the known node-flag strings."""
    if not isinstance(target_nodes, str):
        return False
    return target_nodes in self.node_flags

803 

def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
    """
    Normalize a caller-supplied ``target_nodes`` value into a list of nodes.

    Accepts a list of nodes, a single :class:`ClusterNode`, or a mapping of
    ``{node_name: node}`` (e.g. ``rc.cluster_save_config(rc.get_primaries())``).

    :raises TypeError: for any other type.
    """
    if isinstance(target_nodes, list):
        return target_nodes
    if isinstance(target_nodes, ClusterNode):
        # A single node was passed directly.
        return [target_nodes]
    if isinstance(target_nodes, dict):
        # Mapping of {node_name: node}; only the node values matter here.
        return list(target_nodes.values())
    raise TypeError(
        "target_nodes type can be one of the following: "
        "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
        "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
        f"The passed type is {type(target_nodes)}"
    )

823 

async def _record_error_metric(
    self,
    error: Exception,
    connection: Union[Connection, "ClusterNode"],
    is_internal: bool = True,
    retry_attempts: Optional[int] = None,
):
    """
    Record an error-count observability metric for ``connection``.

    ``connection`` may be a :class:`Connection` or a :class:`ClusterNode`;
    both expose ``host``/``port``.
    """
    host = connection.host
    port = connection.port
    attempts = 0 if retry_attempts is None else retry_attempts
    await record_error_count(
        server_address=host,
        server_port=port,
        network_peer_address=host,
        network_peer_port=port,
        error_type=error,
        retry_attempts=attempts,
        is_internal=is_internal,
    )

844 

async def _record_command_metric(
    self,
    command_name: str,
    duration_seconds: float,
    connection: Union[Connection, "ClusterNode"],
    error: Optional[Exception] = None,
):
    """
    Record an operation-duration observability metric for ``connection``.

    ``connection`` may be a :class:`Connection` or a :class:`ClusterNode`;
    both expose ``host``/``port``.
    """
    # Connection exposes ``db`` directly; ClusterNode keeps it in its
    # connection kwargs (defaulting to 0).
    try:
        db = connection.db
    except AttributeError:
        db = connection.connection_kwargs.get("db", 0)
    await record_operation_duration(
        command_name=command_name,
        duration_seconds=duration_seconds,
        server_address=connection.host,
        server_port=connection.port,
        db_namespace=str(db) if db is not None else None,
        error=error,
    )

869 

    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            # Explicit nodes were passed: pin execution to them and disable
            # retries, since a retry would not re-resolve the targets anyway.
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        # Resolve the command's request/response routing policies by name.
        command_policies = await self._policy_resolver.resolve(args[0].lower())

        if not command_policies and not target_nodes_specified:
            # No server-provided policy: fall back to the static flag mapping,
            # then to slot-based routing.
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    command_policies = CommandPolicies()
                else:
                    # Keyed commands route by slot and aggregate per key.
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        failure_count = 0

        # Start timing for observability
        start_time = time.monotonic()

        for _ in range(execute_attempts):
            if self._initialize:
                # Topology was flagged stale (e.g. after a connection error);
                # rebuild the slots/nodes caches before trying again.
                await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently, then aggregate
                    # the per-node results via the response policy callback.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache should be reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    failure_count += 1

                    if hasattr(e, "connection"):
                        # _execute_command attaches the failing node so the
                        # metrics can be tagged with its address.
                        await self._record_command_metric(
                            command_name=command,
                            duration_seconds=time.monotonic() - start_time,
                            connection=e.connection,
                            error=e,
                        )
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                        )
                    continue
                else:
                    # raise the exception
                    if hasattr(e, "connection"):
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                            is_internal=False,
                        )
                    raise e

1012 

    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Execute *args* on ``target_node``, following ASK/MOVED redirections.

        Retries up to ``RedisClusterRequestTTL`` times while handling cluster
        redirections; other errors are tagged with the failing node (via
        ``e.connection``), recorded as metrics, and re-raised for the caller's
        retry loop in :meth:`execute_command`.
        """
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL
        command = args[0]
        start_time = time.monotonic()

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # Previous attempt got an ASK redirect: switch to the
                    # redirected node and send ASKING before the command.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                response = await target_node.execute_command(*args, **kwargs)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                )
                return response
            except BusyLoadingError as e:
                # Node is loading its dataset; not retryable here.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MaxConnectionsError as e:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ConnectionError, TimeoutError) as e:
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ClusterDownError, SlotNotCoveredError) as e:
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    self.nodes_manager.move_slot(e)
                    moved = True
                # MOVED is handled in-loop: record metrics and fall through to
                # the next iteration rather than re-raising.
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except AskError as e:
                # Remember the redirect target; the next iteration sends
                # ASKING to it before retrying the command.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except TryAgainError as e:
                # Back off slightly once half of the TTL budget is spent.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except ResponseError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except Exception as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise

        # All redirection attempts were consumed without a successful reply.
        e = ClusterError("TTL exhausted.")
        e.connection = target_node
        await self._record_command_metric(
            command_name=command,
            duration_seconds=time.monotonic() - start_time,
            connection=target_node,
            error=e,
        )
        raise e

1198 

1199 def pipeline( 

1200 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None 

1201 ) -> "ClusterPipeline": 

1202 """ 

1203 Create & return a new :class:`~.ClusterPipeline` object. 

1204 

1205 Cluster implementation of pipeline does not support transaction or shard_hint. 

1206 

1207 :raises RedisClusterException: if transaction or shard_hint are truthy values 

1208 """ 

1209 if shard_hint: 

1210 raise RedisClusterException("shard_hint is deprecated in cluster mode") 

1211 

1212 return ClusterPipeline(self, transaction) 

1213 

1214 def lock( 

1215 self, 

1216 name: KeyT, 

1217 timeout: Optional[float] = None, 

1218 sleep: float = 0.1, 

1219 blocking: bool = True, 

1220 blocking_timeout: Optional[float] = None, 

1221 lock_class: Optional[Type[Lock]] = None, 

1222 thread_local: bool = True, 

1223 raise_on_release_error: bool = True, 

1224 ) -> Lock: 

1225 """ 

1226 Return a new Lock object using key ``name`` that mimics 

1227 the behavior of threading.Lock. 

1228 

1229 If specified, ``timeout`` indicates a maximum life for the lock. 

1230 By default, it will remain locked until release() is called. 

1231 

1232 ``sleep`` indicates the amount of time to sleep per loop iteration 

1233 when the lock is in blocking mode and another client is currently 

1234 holding the lock. 

1235 

1236 ``blocking`` indicates whether calling ``acquire`` should block until 

1237 the lock has been acquired or to fail immediately, causing ``acquire`` 

1238 to return False and the lock not being acquired. Defaults to True. 

1239 Note this value can be overridden by passing a ``blocking`` 

1240 argument to ``acquire``. 

1241 

1242 ``blocking_timeout`` indicates the maximum amount of time in seconds to 

1243 spend trying to acquire the lock. A value of ``None`` indicates 

1244 continue trying forever. ``blocking_timeout`` can be specified as a 

1245 float or integer, both representing the number of seconds to wait. 

1246 

1247 ``lock_class`` forces the specified lock implementation. Note that as 

1248 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is 

1249 a Lua-based lock). So, it's unlikely you'll need this parameter, unless 

1250 you have created your own custom lock class. 

1251 

1252 ``thread_local`` indicates whether the lock token is placed in 

1253 thread-local storage. By default, the token is placed in thread local 

1254 storage so that a thread only sees its token, not a token set by 

1255 another thread. Consider the following timeline: 

1256 

1257 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. 

1258 thread-1 sets the token to "abc" 

1259 time: 1, thread-2 blocks trying to acquire `my-lock` using the 

1260 Lock instance. 

1261 time: 5, thread-1 has not yet completed. redis expires the lock 

1262 key. 

1263 time: 5, thread-2 acquired `my-lock` now that it's available. 

1264 thread-2 sets the token to "xyz" 

1265 time: 6, thread-1 finishes its work and calls release(). if the 

1266 token is *not* stored in thread local storage, then 

1267 thread-1 would see the token value as "xyz" and would be 

1268 able to successfully release the thread-2's lock. 

1269 

1270 ``raise_on_release_error`` indicates whether to raise an exception when 

1271 the lock is no longer owned when exiting the context manager. By default, 

1272 this is True, meaning an exception will be raised. If False, the warning 

1273 will be logged and the exception will be suppressed. 

1274 

1275 In some use cases it's necessary to disable thread local storage. For 

1276 example, if you have code where one thread acquires a lock and passes 

1277 that lock instance to a worker thread to release later. If thread 

1278 local storage isn't disabled in this case, the worker thread won't see 

1279 the token set by the thread that acquired the lock. Our assumption 

1280 is that these cases aren't common and as such default to using 

1281 thread local storage.""" 

1282 if lock_class is None: 

1283 lock_class = Lock 

1284 return lock_class( 

1285 self, 

1286 name, 

1287 timeout=timeout, 

1288 sleep=sleep, 

1289 blocking=blocking, 

1290 blocking_timeout=blocking_timeout, 

1291 thread_local=thread_local, 

1292 raise_on_release_error=raise_on_release_error, 

1293 ) 

1294 

1295 async def transaction( 

1296 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs 

1297 ): 

1298 """ 

1299 Convenience method for executing the callable `func` as a transaction 

1300 while watching all keys specified in `watches`. The 'func' callable 

1301 should expect a single argument which is a Pipeline object. 

1302 """ 

1303 shard_hint = kwargs.pop("shard_hint", None) 

1304 value_from_callable = kwargs.pop("value_from_callable", False) 

1305 watch_delay = kwargs.pop("watch_delay", None) 

1306 async with self.pipeline(True, shard_hint) as pipe: 

1307 while True: 

1308 try: 

1309 if watches: 

1310 await pipe.watch(*watches) 

1311 func_value = await func(pipe) 

1312 exec_value = await pipe.execute() 

1313 return func_value if value_from_callable else exec_value 

1314 except WatchError: 

1315 if watch_delay is not None and watch_delay > 0: 

1316 time.sleep(watch_delay) 

1317 continue 

1318 

1319 

class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # Resolve "localhost" to an IP so node names compare equal to the
        # addresses reported by the cluster.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Popped so the callbacks are not forwarded to the Connection
        # constructor via connection_kwargs.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # _connections holds every connection ever created for this node;
        # _free is the subset currently idle and available for reuse.
        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are identified solely by name (host:port).
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn (once) if the node is garbage-collected while any of its
        # connections is still open; also notify the running event loop's
        # exception handler if there is one.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop; the warning above is all we can do.
                    pass
                break

    async def disconnect(self) -> None:
        """Disconnect every connection of this node; re-raise the first error."""
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Pop a free connection, or create a new one if under max_connections.

        :raises MaxConnectionsError: when the pool is exhausted.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, it will be disconnected
        lazily when next acquired via disconnect_if_needed().
        """
        self._free.append(connection)

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected when released back to the pool.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from *connection* and post-process it.

        Honors the NEVER_DECODE and EMPTY_RESPONSE control kwargs, then
        applies the command's response callback if one is registered.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # EMPTY_RESPONSE supplies a fallback value for commands whose
            # error reply should be treated as an empty result.
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """Send a single command on a pooled connection and return the reply."""
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self.disconnect_if_needed(connection)
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all *commands* over one connection and collect their replies.

        Each command's reply (or exception) is stored on the command object
        itself; returns True if any command raised.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                # Keep reading the remaining replies so the connection stays
                # in a consistent state; just record the failure.
                cmd.result = e
                ret = True

        # Release connection
        await self.disconnect_if_needed(connection)
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        # Re-authenticate every currently-free connection with the new token,
        # parking them in a temporary queue so each one is handled exactly once.
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

1582 

1583 

1584class NodesManager: 

1585 __slots__ = ( 

1586 "_dynamic_startup_nodes", 

1587 "_event_dispatcher", 

1588 "_background_tasks", 

1589 "connection_kwargs", 

1590 "default_node", 

1591 "nodes_cache", 

1592 "_epoch", 

1593 "read_load_balancer", 

1594 "_initialize_lock", 

1595 "require_full_coverage", 

1596 "slots_cache", 

1597 "startup_nodes", 

1598 "address_remap", 

1599 ) 

1600 

1601 def __init__( 

1602 self, 

1603 startup_nodes: List["ClusterNode"], 

1604 require_full_coverage: bool, 

1605 connection_kwargs: Dict[str, Any], 

1606 dynamic_startup_nodes: bool = True, 

1607 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, 

1608 event_dispatcher: Optional[EventDispatcher] = None, 

1609 ) -> None: 

1610 self.startup_nodes = {node.name: node for node in startup_nodes} 

1611 self.require_full_coverage = require_full_coverage 

1612 self.connection_kwargs = connection_kwargs 

1613 self.address_remap = address_remap 

1614 

1615 self.default_node: "ClusterNode" = None 

1616 self.nodes_cache: Dict[str, "ClusterNode"] = {} 

1617 self.slots_cache: Dict[int, List["ClusterNode"]] = {} 

1618 self._epoch: int = 0 

1619 self.read_load_balancer = LoadBalancer() 

1620 self._initialize_lock: asyncio.Lock = asyncio.Lock() 

1621 

1622 self._background_tasks: Set[asyncio.Task] = set() 

1623 self._dynamic_startup_nodes: bool = dynamic_startup_nodes 

1624 if event_dispatcher is None: 

1625 self._event_dispatcher = EventDispatcher() 

1626 else: 

1627 self._event_dispatcher = event_dispatcher 

1628 

1629 def get_node( 

1630 self, 

1631 host: Optional[str] = None, 

1632 port: Optional[int] = None, 

1633 node_name: Optional[str] = None, 

1634 ) -> Optional["ClusterNode"]: 

1635 if host and port: 

1636 # the user passed host and port 

1637 if host == "localhost": 

1638 host = socket.gethostbyname(host) 

1639 return self.nodes_cache.get(get_node_name(host=host, port=port)) 

1640 elif node_name: 

1641 return self.nodes_cache.get(node_name) 

1642 else: 

1643 raise DataError( 

1644 "get_node requires one of the following: 1. node name 2. host and port" 

1645 ) 

1646 

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge the *new* node mapping into *old* in place.

        With ``remove_old=True``, nodes absent from *new* are evicted from
        *old* and their connections torn down in the background. Nodes present
        in both mappings are kept (preserving their connection pools) but have
        their connections flagged for lazy reconnect.
        """
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    # Node is removed from cache before disconnect starts,
                    # so it won't be found in lookups during disconnect
                    # Mark active connections for reconnect so they get disconnected after current command completes
                    # and disconnect free connections immediately
                    # the node is removed from the cache before the connections changes so it won't be used and should be safe
                    # not to wait for the disconnects
                    removed_node = old.pop(name)
                    removed_node.update_active_connections_for_reconnect()
                    task = asyncio.create_task(
                        removed_node.disconnect_free_connections()
                    )
                    # Hold a reference so the task isn't garbage-collected
                    # mid-flight; drop it automatically when done.
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)

        for name, node in new.items():
            if name in old:
                # Preserve the existing node but mark connections for reconnect.
                # This method is sync so we can't call disconnect_free_connections()
                # which is async. Instead, we mark free connections for reconnect
                # and they will be lazily disconnected when acquired via
                # disconnect_if_needed() to avoid race conditions.
                # TODO: Make this method async in the next major release to allow
                # immediate disconnection of free connections.
                existing_node = old[name]
                existing_node.update_active_connections_for_reconnect()
                for conn in existing_node._free:
                    conn.mark_for_reconnect()
                continue
            # New node is detected and should be added to the pool
            old[name] = node

1686 

1687 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None: 

1688 """ 

1689 Move a failing node to the end of startup_nodes and nodes_cache so it's 

1690 tried last during reinitialization and when selecting the default node. 

1691 If the node is not in the respective list, nothing is done. 

1692 """ 

1693 # Move in startup_nodes 

1694 if node_name in self.startup_nodes and len(self.startup_nodes) > 1: 

1695 node = self.startup_nodes.pop(node_name) 

1696 self.startup_nodes[node_name] = node # Re-insert at end 

1697 

1698 # Move in nodes_cache - this affects get_nodes_by_server_type ordering 

1699 # which is used to select the default_node during initialize() 

1700 if node_name in self.nodes_cache and len(self.nodes_cache) > 1: 

1701 node = self.nodes_cache.pop(node_name) 

1702 self.nodes_cache[node_name] = node # Re-insert at end 

1703 

    def move_slot(self, e: AskError | MovedError):
        """
        Patch the nodes/slots caches from an ASK or MOVED redirection error.

        The redirecting error carries the new owner's host/port and slot id.
        The new owner is promoted to primary in the cache; on a failover
        (replica promoted within the same shard) the old primary is demoted
        to replica and kept in the slot's node list.
        """
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        slot_nodes = self.slots_cache[e.slot_id]
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op

1740 

1741 def get_node_from_slot( 

1742 self, 

1743 slot: int, 

1744 read_from_replicas: bool = False, 

1745 load_balancing_strategy=None, 

1746 ) -> "ClusterNode": 

1747 if read_from_replicas is True and load_balancing_strategy is None: 

1748 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN 

1749 

1750 try: 

1751 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: 

1752 # get the server index using the strategy defined in load_balancing_strategy 

1753 primary_name = self.slots_cache[slot][0].name 

1754 node_idx = self.read_load_balancer.get_server_index( 

1755 primary_name, len(self.slots_cache[slot]), load_balancing_strategy 

1756 ) 

1757 return self.slots_cache[slot][node_idx] 

1758 return self.slots_cache[slot][0] 

1759 except (IndexError, TypeError): 

1760 raise SlotNotCoveredError( 

1761 f'Slot "{slot}" not covered by the cluster. ' 

1762 f'"require_full_coverage={self.require_full_coverage}"' 

1763 ) 

1764 

1765 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: 

1766 return [ 

1767 node 

1768 for node in self.nodes_cache.values() 

1769 if node.server_type == server_type 

1770 ] 

1771 

    async def initialize(self) -> None:
        """
        (Re)build the cluster topology caches from the startup nodes.

        Queries ``CLUSTER SLOTS`` on each startup node until one answers,
        builds fresh node/slot caches, validates coverage, then atomically
        swaps them in and picks a new default node. Guarded by
        ``_initialize_lock`` and an epoch counter so concurrent callers that
        queued behind an in-flight initialization do not redo the work.
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Snapshot the epoch before taking the lock; compared below to detect
        # that another coroutine already reinitialized while we waited.
        epoch = self._epoch

        async with self._initialize_lock:
            if self._epoch != epoch:
                # another initialize call has already reinitialized the
                # nodes since we started waiting for the lock;
                # we don't need to do it again.
                return

            # Convert to tuple to prevent RuntimeError if self.startup_nodes
            # is modified during iteration
            for startup_node in tuple(self.startup_nodes.values()):
                try:
                    # Make sure cluster mode is enabled on this node
                    try:
                        self._event_dispatcher.dispatch(
                            AfterAsyncClusterInstantiationEvent(
                                self.nodes_cache,
                                self.connection_kwargs.get("credential_provider", None),
                            )
                        )
                        cluster_slots = await startup_node.execute_command(
                            "CLUSTER SLOTS"
                        )
                    except ResponseError:
                        raise RedisClusterException(
                            "Cluster mode is not enabled on this node"
                        )
                    startup_nodes_reachable = True
                except Exception as e:
                    # Try the next startup node.
                    # The exception is saved and raised only if we have no more nodes.
                    exception = e
                    continue

                # CLUSTER SLOTS command results in the following output:
                # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
                # where each node contains the following list: [IP, port, node_id]
                # Therefore, cluster_slots[0][2][0] will be the IP address of the
                # primary node of the first slot section.
                # If there's only one server in the cluster, its ``host`` is ''
                # Fix it to the host in startup_nodes
                if (
                    len(cluster_slots) == 1
                    and not cluster_slots[0][2][0]
                    and len(self.startup_nodes) == 1
                ):
                    cluster_slots[0][2][0] = startup_node.host

                for slot in cluster_slots:
                    # Normalize node entries (host/port/id) from bytes to str.
                    for i in range(2, len(slot)):
                        slot[i] = [str_if_bytes(val) for val in slot[i]]
                    primary_node = slot[2]
                    host = primary_node[0]
                    if host == "":
                        host = startup_node.host
                    port = int(primary_node[1])
                    host, port = self.remap_host_port(host, port)

                    nodes_for_slot = []

                    target_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_node:
                        target_node = ClusterNode(
                            host, port, PRIMARY, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_node.name] = target_node
                    nodes_for_slot.append(target_node)

                    replica_nodes = slot[3:]
                    for replica_node in replica_nodes:
                        host = replica_node[0]
                        port = replica_node[1]
                        host, port = self.remap_host_port(host, port)

                        target_replica_node = tmp_nodes_cache.get(
                            get_node_name(host, port)
                        )
                        if not target_replica_node:
                            target_replica_node = ClusterNode(
                                host, port, REPLICA, **self.connection_kwargs
                            )
                        # add this node to the nodes cache
                        tmp_nodes_cache[target_replica_node.name] = target_replica_node
                        nodes_for_slot.append(target_replica_node)

                    for i in range(int(slot[0]), int(slot[1]) + 1):
                        if i not in tmp_slots:
                            tmp_slots[i] = nodes_for_slot
                        else:
                            # Validate that 2 nodes want to use the same slot cache
                            # setup
                            tmp_slot = tmp_slots[i][0]
                            if tmp_slot.name != target_node.name:
                                disagreements.append(
                                    f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                                )

                                if len(disagreements) > 5:
                                    raise RedisClusterException(
                                        f"startup_nodes could not agree on a valid "
                                        f"slots cache: {', '.join(disagreements)}"
                                    )

                # Validate if all slots are covered or if we should try next startup node
                fully_covered = True
                for i in range(REDIS_CLUSTER_HASH_SLOTS):
                    if i not in tmp_slots:
                        fully_covered = False
                        break
                if fully_covered:
                    break

            if not startup_nodes_reachable:
                raise RedisClusterException(
                    f"Redis Cluster cannot be connected. Please provide at least "
                    f"one reachable node: {str(exception)}"
                ) from exception

            # Check if the slots are not fully covered
            if not fully_covered and self.require_full_coverage:
                # Despite the requirement that the slots be covered, there
                # isn't a full coverage
                raise RedisClusterException(
                    f"All slots are not covered after query all startup_nodes. "
                    f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                    f"covered..."
                )

            # Set the tmp variables to the real variables
            self.slots_cache = tmp_slots
            self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

            if self._dynamic_startup_nodes:
                # Populate the startup nodes with all discovered nodes
                self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

            # Set the default node
            self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
            # Bump the epoch so concurrent waiters (see top of method) skip
            # their own redundant reinitialization.
            self._epoch += 1

1921 

1922 async def aclose(self, attr: str = "nodes_cache") -> None: 

1923 self.default_node = None 

1924 await asyncio.gather( 

1925 *( 

1926 asyncio.create_task(node.disconnect()) 

1927 for node in getattr(self, attr).values() 

1928 ) 

1929 ) 

1930 

1931 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: 

1932 """ 

1933 Remap the host and port returned from the cluster to a different 

1934 internal value. Useful if the client is not connecting directly 

1935 to the cluster. 

1936 """ 

1937 if self.address_remap: 

1938 return self.address_remap((host, port)) 

1939 return host, port 

1940 

1941 

class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
    - :class:`~.ClusterDownError`
    - :class:`~.ConnectionError`
    - :class:`~.TimeoutError`

    Redirection errors:
    - :class:`~.TryAgainError`
    - :class:`~.MovedError`
    - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    """

    __slots__ = ("cluster_client", "_transaction", "_execution_strategy")

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # All queueing/execution behavior is delegated to a strategy object:
        # transactional (MULTI/EXEC) when ``transaction`` is truthy, otherwise
        # the plain fan-out pipeline strategy.
        self._execution_strategy: ExecutionStrategy = (
            PipelineStrategy(self)
            if not self._transaction
            else TransactionStrategy(self)
        )

    @property
    def nodes_manager(self) -> "NodesManager":
        """Get the nodes manager from the cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback on the cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the underlying strategy (and cluster client) and return self."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        # Allow ``pipe = await rc.pipeline()`` as a shorthand for initialize().
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        "Pipeline instances should always evaluate to True on Python 3+"
        return True

    def __len__(self) -> int:
        # Number of commands currently queued in the strategy.
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            # Always leave the pipeline empty, whether execution succeeded or not.
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        # Queue one copy of ``command`` per hash slot covered by ``keys``.
        for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
            self.execute_command(command, *slot_keys)

        return self

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """Discard all queued commands of a transactional pipeline (DISCARD)."""
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``"""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Queue an UNLINK for a single key (delegated to the strategy)."""
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue per-slot MSET commands for ``mapping`` (non-atomic across slots)."""
        return self._execution_strategy.mset_nonatomic(mapping)

2102 

2103 

# Install stub methods that raise for every command that cannot run inside a
# cluster pipeline. ``mset_nonatomic`` is skipped because ClusterPipeline
# provides a real implementation for it.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))

2110 

2111 

class PipelineCommand:
    """A single queued pipeline command: raw args/kwargs, its queue position,
    the eventual result (or the exception it raised), and resolved policies."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        self.position = position  # index of this command in the pipeline queue
        self.args = args
        self.kwargs = kwargs
        self.result: Union[Any, Exception] = None
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)

2122 

2123 

class ExecutionStrategy(ABC):
    """
    Abstract interface implemented by ClusterPipeline's execution backends
    (:class:`PipelineStrategy` for plain pipelining, :class:`TransactionStrategy`
    for MULTI/EXEC transactions).
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the execution strategy.

        See ClusterPipeline.initialize()
        """
        pass

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        See ClusterPipeline.execute_command()
        """
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        See ClusterPipeline.execute()
        """
        pass

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Executes multiple MSET commands according to the provided slot/pairs mapping.

        See ClusterPipeline.mset_nonatomic()
        """
        pass

    @abstractmethod
    async def reset(self):
        """
        Resets current execution strategy.

        See: ClusterPipeline.reset()
        """
        pass

    @abstractmethod
    def multi(self):
        """
        Starts transactional context.

        See: ClusterPipeline.multi()
        """
        pass

    @abstractmethod
    async def watch(self, *names):
        """
        Watch given keys.

        See: ClusterPipeline.watch()
        """
        pass

    @abstractmethod
    async def unwatch(self):
        """
        Unwatches all previously specified keys

        See: ClusterPipeline.unwatch()
        """
        pass

    @abstractmethod
    async def discard(self):
        """
        Discard queued transaction commands.

        See: ClusterPipeline.discard()
        """
        pass

    @abstractmethod
    async def unlink(self, *names):
        """
        "Unlink a key specified by ``names``"

        See: ClusterPipeline.unlink()
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Number of commands currently queued by the strategy."""
        pass

2222 

2223 

class AbstractStrategy(ExecutionStrategy):
    """
    Shared plumbing for pipeline execution strategies: holds the owning
    pipeline, manages the queued-command list, and annotates exceptions
    with the position/text of the command that caused them.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Ensure the cluster client is initialized and start from an empty queue."""
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Queue a raw command (with its position in the pipeline) for later execution."""
        position = len(self._command_queue)
        self._command_queue.append(PipelineCommand(position, *args, **kwargs))
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled
        """
        rendered = " ".join(safe_str(token) for token in command)
        msg = (
            f"Command # {number} ({truncate_text(rendered)}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        return len(self._command_queue)

2292 

2293 

class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional pipeline execution: queued commands are grouped by
    target node, sent concurrently via each node's ``execute_pipeline``,
    and redirection errors are optionally retried command-by-command.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot covering the key/value pairs in *mapping*."""
        encoder = self._pipe.cluster_client.encoder

        # Group the flat key/value arguments by the slot their key hashes to.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands, retrying whole-pipeline execution on
        cluster-level errors (``ERRORS_ALLOW_RETRY``) up to the client's
        configured retry count; always clears the queue afterwards.
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        # Close connections so the next attempt re-discovers
                        # the (possibly changed) topology.
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # All other errors should be raised.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        One pipeline attempt: resolve a target node and command policies for
        each unfinished command, fan the per-node batches out concurrently,
        then handle redirections / raise the first error / rotate the default
        node as needed. Returns results in original queue order.
        """
        # Commands with no result yet, or whose previous attempt failed.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        nodes = {}
        for cmd in todo:
            # NOTE(review): pop() mutates cmd.kwargs, so an explicit
            # target_nodes is consumed on the first attempt — verify retries
            # are expected to fall back to slot-based routing.
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            # Pipelined commands must map to exactly one node each.
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Start timing for observability
        start_time = time.monotonic()

        # Execute each node's batch concurrently; results/errors are written
        # into each PipelineCommand's ``result`` by execute_pipeline.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        # Record operation duration for each node
        for node_name, (node, commands) in nodes.items():
            # Find the first error in this node's commands, if any
            node_error = None
            for cmd in commands:
                if isinstance(cmd.result, Exception):
                    node_error = cmd.result
                    break

            db = node.connection_kwargs.get("db", 0)
            await record_operation_duration(
                command_name="PIPELINE",
                duration_seconds=time.monotonic() - start_time,
                server_address=node.host,
                server_port=node.port,
                db_namespace=str(db) if db is not None else None,
                error=node_error,
            )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        """Not available outside a transaction; always raises."""
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        """Not available outside a transaction; always raises."""
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        """Not available outside a transaction; always raises."""
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        """Not available outside a transaction; always raises."""
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue UNLINK for exactly one key; multi-key unlink is unsupported here."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])

2519 

2520 

class TransactionStrategy(AbstractStrategy):
    """
    MULTI/EXEC (and WATCH) execution for ClusterPipeline.

    All keys in one transaction must hash to a single slot; the strategy pins
    one node/connection for the whole transaction so WATCH guarantees hold.
    """

    # Commands that do not map to any hash slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands sent to the server immediately instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands whose completion clears the WATCH state.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    # Errors signalling that a slot was redirected to another node.
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    # Errors that invalidate the pinned transaction connection.
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False  # True once multi() was called
        self._watching = False  # True while WATCHed keys are active
        self._pipeline_slots: Set[int] = set()  # slots touched so far (must stay 1)
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Use a copy of the client's retry so transaction-specific error
        # classes can be added without mutating the shared policy.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

2545 

2546 def _get_client_and_connection_for_transaction( 

2547 self, 

2548 ) -> Tuple[ClusterNode, Connection]: 

2549 """ 

2550 Find a connection for a pipeline transaction. 

2551 

2552 For running an atomic transaction, watch keys ensure that contents have not been 

2553 altered as long as the watch commands for those keys were sent over the same 

2554 connection. So once we start watching a key, we fetch a connection to the 

2555 node that owns that slot and reuse it. 

2556 """ 

2557 if not self._pipeline_slots: 

2558 raise RedisClusterException( 

2559 "At least a command with a key is needed to identify a node" 

2560 ) 

2561 

2562 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot( 

2563 list(self._pipeline_slots)[0], False 

2564 ) 

2565 self._transaction_node = node 

2566 

2567 if not self._transaction_connection: 

2568 connection: Connection = self._transaction_node.acquire_connection() 

2569 self._transaction_connection = connection 

2570 

2571 return self._transaction_node, self._transaction_connection 

2572 

2573 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any": 

2574 # Given the limitation of ClusterPipeline sync API, we have to run it in thread. 

2575 response = None 

2576 error = None 

2577 

2578 def runner(): 

2579 nonlocal response 

2580 nonlocal error 

2581 try: 

2582 response = asyncio.run(self._execute_command(*args, **kwargs)) 

2583 except Exception as e: 

2584 error = e 

2585 

2586 thread = threading.Thread(target=runner) 

2587 thread.start() 

2588 thread.join() 

2589 

2590 if error: 

2591 raise error 

2592 

2593 return response 

2594 

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Route one command: execute it immediately when we are WATCHing (or it
        is WATCH/UNWATCH itself) outside an explicit MULTI, otherwise queue it.
        Tracks the slots touched so cross-slot usage can be rejected.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                # All keys in a cluster transaction must live in one slot.
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # Returns the coroutine; awaited by the caller's event loop.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            # Queue for execution at EXEC time.
            return super().execute_command(*args, **kwargs)

2630 

2631 def _validate_watch(self): 

2632 if self._explicit_transaction: 

2633 raise RedisError("Cannot issue a WATCH after a MULTI") 

2634 

2635 self._watching = True 

2636 

2637 async def _immediate_execute_command(self, *args, **options): 

2638 return await self._retry.call_with_retry( 

2639 lambda: self._get_connection_and_send_command(*args, **options), 

2640 self._reinitialize_on_error, 

2641 with_failure_count=True, 

2642 ) 

2643 

    async def _get_connection_and_send_command(self, *args, **options):
        """
        Acquire the transaction's node/connection, execute one command on it,
        and record the operation duration (with the error, on failure).
        On failure the connection is attached to the exception so the retry
        hook can report/release it.
        """
        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        # Start timing for observability
        start_time = time.monotonic()

        try:
            response = await self._send_command_parse_response(
                connection, redis_node, args[0], *args, **options
            )

            await record_operation_duration(
                command_name=args[0],
                duration_seconds=time.monotonic() - start_time,
                server_address=connection.host,
                server_port=connection.port,
                db_namespace=str(connection.db),
            )

            return response
        except Exception as e:
            # Expose the failing connection to _reinitialize_on_error.
            e.connection = connection
            await record_operation_duration(
                command_name=args[0],
                duration_seconds=time.monotonic() - start_time,
                server_address=connection.host,
                server_port=connection.port,
                db_namespace=str(connection.db),
                error=e,
            )
            raise

2678 

2679 async def _send_command_parse_response( 

2680 self, 

2681 connection: Connection, 

2682 redis_node: ClusterNode, 

2683 command_name, 

2684 *args, 

2685 **options, 

2686 ): 

2687 """ 

2688 Send a command and parse the response 

2689 """ 

2690 

2691 await connection.send_command(*args) 

2692 output = await redis_node.parse_response(connection, command_name, **options) 

2693 

2694 if command_name in self.UNWATCH_COMMANDS: 

2695 self._watching = False 

2696 return output 

2697 

2698 async def _reinitialize_on_error(self, error, failure_count): 

2699 if hasattr(error, "connection"): 

2700 await record_error_count( 

2701 server_address=error.connection.host, 

2702 server_port=error.connection.port, 

2703 network_peer_address=error.connection.host, 

2704 network_peer_port=error.connection.port, 

2705 error_type=error, 

2706 retry_attempts=failure_count, 

2707 is_internal=True, 

2708 ) 

2709 

2710 if self._watching: 

2711 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: 

2712 raise WatchError("Slot rebalancing occurred while watching keys") 

2713 

2714 if ( 

2715 type(error) in self.SLOT_REDIRECT_ERRORS 

2716 or type(error) in self.CONNECTION_ERRORS 

2717 ): 

2718 if self._transaction_connection and self._transaction_node: 

2719 # Disconnect and release back to pool 

2720 await self._transaction_connection.disconnect() 

2721 self._transaction_node.release(self._transaction_connection) 

2722 self._transaction_connection = None 

2723 

2724 self._pipe.cluster_client.reinitialize_counter += 1 

2725 if ( 

2726 self._pipe.cluster_client.reinitialize_steps 

2727 and self._pipe.cluster_client.reinitialize_counter 

2728 % self._pipe.cluster_client.reinitialize_steps 

2729 == 0 

2730 ): 

2731 await self._pipe.cluster_client.nodes_manager.initialize() 

2732 self.reinitialize_counter = 0 

2733 else: 

2734 if isinstance(error, AskError): 

2735 self._pipe.cluster_client.nodes_manager.move_slot(error) 

2736 

2737 self._executing = False 

2738 

2739 async def _raise_first_error(self, responses, stack, start_time): 

2740 """ 

2741 Raise the first exception on the stack 

2742 """ 

2743 for r, cmd in zip(responses, stack): 

2744 if isinstance(r, Exception): 

2745 self._annotate_exception(r, cmd.position + 1, cmd.args) 

2746 

2747 await record_operation_duration( 

2748 command_name="TRANSACTION", 

2749 duration_seconds=time.monotonic() - start_time, 

2750 server_address=self._transaction_connection.host, 

2751 server_port=self._transaction_connection.port, 

2752 db_namespace=str(self._transaction_connection.db), 

2753 error=r, 

2754 ) 

2755 

2756 raise r 

2757 

2758 def mset_nonatomic( 

2759 self, mapping: Mapping[AnyKeyT, EncodableT] 

2760 ) -> "ClusterPipeline": 

2761 raise NotImplementedError("Method is not supported in transactional context.") 

2762 

2763 async def execute( 

2764 self, raise_on_error: bool = True, allow_redirections: bool = True 

2765 ) -> List[Any]: 

2766 stack = self._command_queue 

2767 if not stack and (not self._watching or not self._pipeline_slots): 

2768 return [] 

2769 

2770 return await self._execute_transaction_with_retries(stack, raise_on_error) 

2771 

2772 async def _execute_transaction_with_retries( 

2773 self, stack: List["PipelineCommand"], raise_on_error: bool 

2774 ): 

2775 return await self._retry.call_with_retry( 

2776 lambda: self._execute_transaction(stack, raise_on_error), 

2777 lambda error, failure_count: self._reinitialize_on_error( 

2778 error, failure_count 

2779 ), 

2780 with_failure_count=True, 

2781 ) 

2782 

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Send the queued commands as a single MULTI/EXEC block.

        Packs MULTI + queued commands + EXEC into one write, then reads the
        replies back in the same order so the socket stays in sync even when
        individual commands error. Returns the post-callback results, one per
        queued command.

        :raises CrossSlotTransactionError: keys span more than one hash slot
        :raises WatchError: a WATCHed key changed and EXEC returned nil
        :raises InvalidPipelineStack: reply count does not match the queue
        """
        # A Redis Cluster transaction must execute on exactly one node, so all
        # keys must hash to the same slot.
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        # Wrap the user's commands in MULTI ... EXEC before packing.
        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands flagged with EMPTY_RESPONSE are not sent at all; their
        # placeholder result is injected into the response list later.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            # Attach the connection so the retry hook can record which node
            # failed (see _reinitialize_on_error).
            cluster_error.connection = connection
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                # Not sent on the wire; remember (index, placeholder) so the
                # placeholder can be spliced into the response afterwards.
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    # Each queued command replies "+QUEUED" (or an error) —
                    # the reply is read only to keep the socket in step.
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    # NOTE(review): appended as a bare exception, while the
                    # EMPTY_RESPONSE branch appends (index, value) tuples —
                    # the `for i, e in errors` unpacking below assumes tuples.
                    # Verify against upstream redis-py behavior.
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # The server aborted the transaction; prefer surfacing the first
            # queuing error (the underlying cause) over the abort itself.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        # A nil EXEC reply means a WATCHed key changed and nothing ran.
        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): this unpacking requires (index, value) tuples; bare
        # exceptions appended above would raise TypeError here — confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data

2901 

    async def reset(self):
        """Return the pipeline to a pristine state.

        Empties the command queue, releases the transaction connection back to
        its node's pool (sending UNWATCH first when keys are being watched),
        and clears all transaction/WATCH bookkeeping flags.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection and self._transaction_node:
                    await self._transaction_connection.disconnect()
                    self._transaction_node.release(self._transaction_connection)
                    self._transaction_connection = None

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

2931 

2932 def multi(self): 

2933 if self._explicit_transaction: 

2934 raise RedisError("Cannot issue nested calls to MULTI") 

2935 if self._command_queue: 

2936 raise RedisError( 

2937 "Commands without an initial WATCH have already been issued" 

2938 ) 

2939 self._explicit_transaction = True 

2940 

2941 async def watch(self, *names): 

2942 if self._explicit_transaction: 

2943 raise RedisError("Cannot issue a WATCH after a MULTI") 

2944 

2945 return await self.execute_command("WATCH", *names) 

2946 

2947 async def unwatch(self): 

2948 if self._watching: 

2949 return await self.execute_command("UNWATCH") 

2950 

2951 return True 

2952 

2953 async def discard(self): 

2954 await self.reset() 

2955 

2956 async def unlink(self, *names): 

2957 return self.execute_command("UNLINK", *names)