Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/cluster.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1050 statements  

1import asyncio 

2import collections 

3import random 

4import socket 

5import threading 

6import time 

7import warnings 

8from abc import ABC, abstractmethod 

9from copy import copy 

10from itertools import chain 

11from typing import ( 

12 Any, 

13 Callable, 

14 Coroutine, 

15 Deque, 

16 Dict, 

17 Generator, 

18 List, 

19 Mapping, 

20 Optional, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypeVar, 

25 Union, 

26) 

27 

28from redis._parsers import AsyncCommandsParser, Encoder 

29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy 

30from redis._parsers.helpers import ( 

31 _RedisCallbacks, 

32 _RedisCallbacksRESP2, 

33 _RedisCallbacksRESP3, 

34) 

35from redis.asyncio.client import ResponseCallbackT 

36from redis.asyncio.connection import Connection, SSLConnection, parse_url 

37from redis.asyncio.lock import Lock 

38from redis.asyncio.retry import Retry 

39from redis.auth.token import TokenInterface 

40from redis.backoff import ExponentialWithJitterBackoff, NoBackoff 

41from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis 

42from redis.cluster import ( 

43 PIPELINE_BLOCKED_COMMANDS, 

44 PRIMARY, 

45 REPLICA, 

46 SLOT_ID, 

47 AbstractRedisCluster, 

48 LoadBalancer, 

49 LoadBalancingStrategy, 

50 block_pipeline_command, 

51 get_node_name, 

52 parse_cluster_slots, 

53) 

54from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands 

55from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver 

56from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot 

57from redis.credentials import CredentialProvider 

58from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher 

59from redis.exceptions import ( 

60 AskError, 

61 BusyLoadingError, 

62 ClusterDownError, 

63 ClusterError, 

64 ConnectionError, 

65 CrossSlotTransactionError, 

66 DataError, 

67 ExecAbortError, 

68 InvalidPipelineStack, 

69 MaxConnectionsError, 

70 MovedError, 

71 RedisClusterException, 

72 RedisError, 

73 ResponseError, 

74 SlotNotCoveredError, 

75 TimeoutError, 

76 TryAgainError, 

77 WatchError, 

78) 

79from redis.typing import AnyKeyT, EncodableT, KeyT 

80from redis.utils import ( 

81 SSL_AVAILABLE, 

82 deprecated_args, 

83 deprecated_function, 

84 get_lib_version, 

85 safe_str, 

86 str_if_bytes, 

87 truncate_text, 

88) 

89 

# When the ssl module is unavailable the three names are still bound (to None)
# so that signatures annotated with them keep resolving.
if not SSL_AVAILABLE:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None
else:
    from ssl import TLSVersion, VerifyFlags, VerifyMode

# Shapes accepted for a ``target_nodes`` value: a flag string, a single node,
# a list of nodes, or a mapping whose values are nodes.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)

100 

101 

102class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

103 """ 

104 Create a new RedisCluster client. 

105 

106 Pass one of parameters: 

107 

108 - `host` & `port` 

109 - `startup_nodes` 

110 

111 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. 

112 | Use ``await`` :meth:`close` to disconnect connections & close client. 

113 

114 Many commands support the target_nodes kwarg. It can be one of the 

115 :attr:`NODE_FLAGS`: 

116 

117 - :attr:`PRIMARIES` 

118 - :attr:`REPLICAS` 

119 - :attr:`ALL_NODES` 

120 - :attr:`RANDOM` 

121 - :attr:`DEFAULT_NODE` 

122 

123 Note: This client is not thread/process/fork safe. 

124 

125 :param host: 

126 | Can be used to point to a startup node 

127 :param port: 

128 | Port used if **host** is provided 

129 :param startup_nodes: 

130 | List of :class:`~.ClusterNode` objects to use as startup nodes 

131 :param require_full_coverage: 

132 | When set to ``False``: the client will not require a full coverage of 

133 the slots. However, if not all slots are covered, and at least one node 

134 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw 

135 a :class:`~.ClusterDownError` for some key-based commands. 

136 | When set to ``True``: all slots must be covered to construct the cluster 

137 client. If not all slots are covered, :class:`~.RedisClusterException` will be 

138 thrown. 

139 | See: 

140 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters 

141 :param read_from_replicas: 

142 | @deprecated - please use load_balancing_strategy instead 

143 | Enable read from replicas in READONLY mode. 

144 When set to true, read commands will be assigned between the primary and 

145 its replications in a Round-Robin manner. 

146 The data read from replicas is eventually consistent with the data in primary nodes. 

147 :param load_balancing_strategy: 

148 | Enable read from replicas in READONLY mode and defines the load balancing 

149 strategy that will be used for cluster node selection. 

150 The data read from replicas is eventually consistent with the data in primary nodes. 

151 :param dynamic_startup_nodes: 

152 | Set the RedisCluster's startup nodes to all the discovered nodes. 

153 If true (default value), the cluster's discovered nodes will be used to 

154 determine the cluster nodes-slots mapping in the next topology refresh. 

155 It will remove the initial passed startup nodes if their endpoints aren't 

156 listed in the CLUSTER SLOTS output. 

157 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists 

158 specific IP addresses, it is best to set it to false. 

159 :param reinitialize_steps: 

160 | Specifies the number of MOVED errors that need to occur before reinitializing 

161 the whole cluster topology. If a MOVED error occurs and the cluster does not 

162 need to be reinitialized on this current error handling, only the MOVED slot 

163 will be patched with the redirected node. 

164 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. 

165 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 

166 0. 

167 :param cluster_error_retry_attempts: 

168 | @deprecated - Please configure the 'retry' object instead 

169 In case 'retry' object is set - this argument is ignored! 

170 

171 Number of times to retry before raising an error when :class:`~.TimeoutError`, 

172 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` 

173 or :class:`~.ClusterDownError` are encountered 

174 :param retry: 

175 | A retry object that defines the retry strategy and the number of 

176 retries for the cluster client. 

177 In current implementation for the cluster client (starting from redis-py version 6.0.0) 

178 the retry object is not yet fully utilized, instead it is used just to determine 

179 the number of retries for the cluster client. 

180 In the future releases the retry object will be used to handle the cluster client retries! 

181 :param max_connections: 

182 | Maximum number of connections per node. If there are no free connections & the 

183 maximum number of connections are already created, a 

184 :class:`~.MaxConnectionsError` is raised. 

185 :param address_remap: 

186 | An optional callable which, when provided with an internal network 

187 address of a node, e.g. a `(host, port)` tuple, will return the address 

188 where the node is reachable. This can be used to map the addresses at 

189 which the nodes _think_ they are, to addresses at which a client may 

190 reach them, such as when they sit behind a proxy. 

191 

192 | Rest of the arguments will be passed to the 

193 :class:`~redis.asyncio.connection.Connection` instances when created 

194 

195 :raises RedisClusterException: 

196 if any arguments are invalid or unknown. Eg: 

197 

198 - `db` != 0 or None 

199 - `path` argument for unix socket connection 

200 - none of the `host`/`port` & `startup_nodes` were provided 

201 

202 """ 

203 

@classmethod
def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
    """
    Build a cluster client from a connection URL.

    For example::

        redis://[[username]:[password]]@localhost:6379/0
        rediss://[[username]:[password]]@localhost:6379/0

    Two URL schemes are supported:

    - `redis://` creates a TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/redis>
    - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/rediss>

    The URL is handed to ``parse_url``; the options it returns are merged
    over ``kwargs``, so a value coming from the URL wins over a keyword
    argument of the same name. If the parsed options select
    ``SSLConnection`` as ``connection_class``, that entry is replaced by
    ``ssl=True`` before the client is constructed.

    NOTE(review): unquoting of percent-encoded values and casting of
    querystring options (with ``ValueError`` on bad values) happen inside
    ``parse_url``, which is not defined in this file -- confirm there.
    """
    url_options = parse_url(url)
    kwargs.update(url_options)

    # The constructor takes an ``ssl`` flag rather than a connection class.
    connection_class = kwargs.pop("connection_class", None)
    if connection_class is SSLConnection:
        kwargs["ssl"] = True

    return cls(**kwargs)

236 

# Attribute names for which the class declares slots.
# NOTE(review): __init__ also assigns attributes that are not listed here
# (e.g. _event_dispatcher, startup_nodes, load_balancing_strategy,
# _usage_counter, _usage_lock); presumably a base class provides an
# instance __dict__ -- confirm before relying on slot-only storage.
__slots__ = (
    "_initialize",
    "_lock",
    "retry",
    "command_flags",
    "commands_parser",
    "connection_kwargs",
    "encoder",
    "node_flags",
    "nodes_manager",
    "read_from_replicas",
    "reinitialize_counter",
    "reinitialize_steps",
    "response_callbacks",
    "result_callbacks",
)

253 

@deprecated_args(
    args_to_warn=["read_from_replicas"],
    reason="Please configure the 'load_balancing_strategy' instead",
    version="5.3.0",
)
@deprecated_args(
    args_to_warn=[
        "cluster_error_retry_attempts",
    ],
    reason="Please configure the 'retry' object instead",
    version="6.0.0",
)
def __init__(
    self,
    host: Optional[str] = None,
    port: Union[str, int] = 6379,
    # Cluster related kwargs
    startup_nodes: Optional[List["ClusterNode"]] = None,
    require_full_coverage: bool = True,
    read_from_replicas: bool = False,
    load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
    dynamic_startup_nodes: bool = True,
    reinitialize_steps: int = 5,
    cluster_error_retry_attempts: int = 3,
    max_connections: int = 2**31,
    retry: Optional["Retry"] = None,
    retry_on_error: Optional[List[Type[Exception]]] = None,
    # Client related kwargs
    db: Union[str, int] = 0,
    path: Optional[str] = None,
    credential_provider: Optional[CredentialProvider] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    client_name: Optional[str] = None,
    lib_name: Optional[str] = "redis-py",
    lib_version: Optional[str] = get_lib_version(),
    # Encoding related kwargs
    encoding: str = "utf-8",
    encoding_errors: str = "strict",
    decode_responses: bool = False,
    # Connection related kwargs
    health_check_interval: float = 0,
    socket_connect_timeout: Optional[float] = None,
    socket_keepalive: bool = False,
    socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
    socket_timeout: Optional[float] = None,
    # SSL related kwargs
    ssl: bool = False,
    ssl_ca_certs: Optional[str] = None,
    ssl_ca_data: Optional[str] = None,
    ssl_cert_reqs: Union[str, VerifyMode] = "required",
    ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_certfile: Optional[str] = None,
    ssl_check_hostname: bool = True,
    ssl_keyfile: Optional[str] = None,
    ssl_min_version: Optional[TLSVersion] = None,
    ssl_ciphers: Optional[str] = None,
    protocol: Optional[int] = 2,
    address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
    event_dispatcher: Optional[EventDispatcher] = None,
    policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
) -> None:
    """
    Validate the arguments and assemble the client's state.

    No awaiting happens here: the constructor only builds the connection
    kwargs, the startup :class:`~.ClusterNode` objects, the
    :class:`~.NodesManager` and the lookup tables. ``_initialize`` is left
    ``True`` so that :meth:`initialize` performs the actual setup later.

    :raises RedisClusterException:
        if ``db`` is truthy, if ``path`` is given, or if neither a
        ``host``/``port`` pair nor ``startup_nodes`` was provided.
    """
    # --- argument validation -------------------------------------------
    if db:
        raise RedisClusterException("Argument 'db' must be 0 or None in cluster mode")

    if path:
        raise RedisClusterException(
            "Unix domain socket is not supported in cluster mode"
        )

    if not (host and port) and not startup_nodes:
        raise RedisClusterException(
            "RedisCluster requires at least one node to discover the cluster.\n"
            "Please provide one of the following or use RedisCluster.from_url:\n"
            ' - host and port: RedisCluster(host="localhost", port=6379)\n'
            " - startup_nodes: RedisCluster(startup_nodes=["
            'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
        )

    # --- kwargs shared by every node connection ------------------------
    conn_kwargs: Dict[str, Any] = dict(
        max_connections=max_connections,
        connection_class=Connection,
        # Client related kwargs
        credential_provider=credential_provider,
        username=username,
        password=password,
        client_name=client_name,
        lib_name=lib_name,
        lib_version=lib_version,
        # Encoding related kwargs
        encoding=encoding,
        encoding_errors=encoding_errors,
        decode_responses=decode_responses,
        # Connection related kwargs
        health_check_interval=health_check_interval,
        socket_connect_timeout=socket_connect_timeout,
        socket_keepalive=socket_keepalive,
        socket_keepalive_options=socket_keepalive_options,
        socket_timeout=socket_timeout,
        protocol=protocol,
    )

    if ssl:
        # SSL related kwargs; also swaps in the SSL connection class
        conn_kwargs.update(
            connection_class=SSLConnection,
            ssl_ca_certs=ssl_ca_certs,
            ssl_ca_data=ssl_ca_data,
            ssl_cert_reqs=ssl_cert_reqs,
            ssl_include_verify_flags=ssl_include_verify_flags,
            ssl_exclude_verify_flags=ssl_exclude_verify_flags,
            ssl_certfile=ssl_certfile,
            ssl_check_hostname=ssl_check_hostname,
            ssl_keyfile=ssl_keyfile,
            ssl_min_version=ssl_min_version,
            ssl_ciphers=ssl_ciphers,
        )

    if read_from_replicas or load_balancing_strategy:
        # Call our on_connect function to configure READONLY mode
        conn_kwargs["redis_connect_func"] = self.on_connect

    # --- retry policy --------------------------------------------------
    # A caller-supplied retry object wins; otherwise the deprecated
    # cluster_error_retry_attempts value sizes a default one.
    self.retry = retry or Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10),
        retries=cluster_error_retry_attempts,
    )
    if retry_on_error:
        self.retry.update_supported_errors(retry_on_error)

    # --- response callbacks (protocol dependent) -----------------------
    callbacks = _RedisCallbacks.copy()
    if conn_kwargs.get("protocol") in ["3", 3]:
        callbacks.update(_RedisCallbacksRESP3)
    else:
        callbacks.update(_RedisCallbacksRESP2)
    conn_kwargs["response_callbacks"] = callbacks
    self.connection_kwargs = conn_kwargs

    # --- startup nodes -------------------------------------------------
    # Passed nodes are rebuilt so that they carry this client's kwargs.
    seed_nodes = [
        ClusterNode(seed.host, seed.port, **conn_kwargs)
        for seed in startup_nodes or ()
    ]
    if host and port:
        seed_nodes.append(ClusterNode(host, port, **conn_kwargs))

    self._event_dispatcher = (
        EventDispatcher() if event_dispatcher is None else event_dispatcher
    )

    self.startup_nodes = seed_nodes
    self.nodes_manager = NodesManager(
        seed_nodes,
        require_full_coverage,
        conn_kwargs,
        dynamic_startup_nodes=dynamic_startup_nodes,
        address_remap=address_remap,
        event_dispatcher=self._event_dispatcher,
    )
    self.encoder = Encoder(encoding, encoding_errors, decode_responses)
    self.read_from_replicas = read_from_replicas
    self.load_balancing_strategy = load_balancing_strategy
    self.reinitialize_steps = reinitialize_steps
    self.reinitialize_counter = 0

    # For backward compatibility, mapping from existing policies to new one
    node_flag_owner = self.__class__
    self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
        node_flag_owner.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
        node_flag_owner.PRIMARIES: RequestPolicy.ALL_SHARDS,
        node_flag_owner.ALL_NODES: RequestPolicy.ALL_NODES,
        node_flag_owner.REPLICAS: RequestPolicy.ALL_REPLICAS,
        node_flag_owner.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
        SLOT_ID: RequestPolicy.DEFAULT_KEYED,
    }

    # Policy -> callable producing the target nodes (request policies) or
    # post-processing the result (response policies).
    self._policies_callback_mapping: dict[
        Union[RequestPolicy, ResponsePolicy], Callable
    ] = {
        RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
            self.get_random_primary_or_all_nodes(command_name)
        ],
        RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
        RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
        RequestPolicy.ALL_SHARDS: self.get_primaries,
        RequestPolicy.ALL_NODES: self.get_nodes,
        RequestPolicy.ALL_REPLICAS: self.get_replicas,
        RequestPolicy.SPECIAL: self.get_special_nodes,
        ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
        ResponsePolicy.DEFAULT_KEYED: lambda res: res,
    }

    self._policy_resolver = policy_resolver
    self.commands_parser = AsyncCommandsParser()
    self._aggregate_nodes = None
    self.node_flags = node_flag_owner.NODE_FLAGS.copy()
    self.command_flags = node_flag_owner.COMMAND_FLAGS.copy()
    self.response_callbacks = conn_kwargs["response_callbacks"]
    self.result_callbacks = node_flag_owner.RESULT_CALLBACKS.copy()
    # CLUSTER SLOTS results arrive keyed by node; only the first value is parsed.
    self.result_callbacks["CLUSTER SLOTS"] = (
        lambda cmd, res, **kwargs: parse_cluster_slots(
            list(res.values())[0], **kwargs
        )
    )

    # True means "initialize() still has work to do".
    self._initialize = True
    self._lock: Optional[asyncio.Lock] = None

    # When used as an async context manager, we need to increment and decrement
    # a usage counter so that we can close the connection pool when no one is
    # using the client.
    self._usage_counter = 0
    self._usage_lock = asyncio.Lock()

477 

async def initialize(self) -> "RedisCluster":
    """
    Discover the cluster and prepare the commands parser, once.

    Does nothing when the client is already initialized. The flag is checked
    again after the lock is acquired, so a caller that waited on the lock
    does not repeat work another caller just finished. If any step fails,
    the nodes manager is closed (default call, then ``"startup_nodes"``)
    and the error is re-raised.

    :returns: the client itself.
    """
    if not self._initialize:
        return self

    # The lock is created on first use.
    self._lock = self._lock or asyncio.Lock()

    async with self._lock:
        if not self._initialize:
            # Someone else completed initialization while we waited.
            return self

        manager = self.nodes_manager
        try:
            await manager.initialize()
            await self.commands_parser.initialize(manager.default_node)
            self._initialize = False
        except BaseException:
            # BaseException on purpose: clean up on cancellation as well.
            await manager.aclose()
            await manager.aclose("startup_nodes")
            raise

    return self

496 

async def aclose(self) -> None:
    """
    Close the nodes manager if the client is currently initialized.

    The client is flagged as needing initialization again *before* the
    nodes manager is closed (default call, then ``"startup_nodes"``).
    Calling this on a client that is not initialized is a no-op.
    """
    if self._initialize:
        return

    # The lock is created on first use.
    self._lock = self._lock or asyncio.Lock()

    async with self._lock:
        if self._initialize:
            # Another caller closed the client while we waited for the lock.
            return
        self._initialize = True
        manager = self.nodes_manager
        await manager.aclose()
        await manager.aclose("startup_nodes")

507 

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
    """
    Deprecated alias for :meth:`aclose`, kept for backwards compatibility.

    Simply awaits :meth:`aclose`.
    """
    await self.aclose()

512 

async def __aenter__(self) -> "RedisCluster":
    """
    Async context manager entry.

    Increments the usage counter, then initializes the client. The counter
    exists so that the connection pool is only closed (via ``aclose()``)
    once no context is using the client any more.

    :returns: the initialized client.
    :raises BaseException:
        whatever :meth:`initialize` raises, after the usage counter has
        been decremented again.
    """
    await self._increment_usage()
    try:
        # Initialize the client (i.e. establish connection, etc.)
        return await self.initialize()
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError is a
        # BaseException, and when __aenter__ fails __aexit__ never runs.
        # Without this the counter would stay incremented after a cancelled
        # initialization and could never drop back to zero.
        await self._decrement_usage()
        raise

527 

async def _increment_usage(self) -> int:
    """
    Add one to the usage counter while holding ``_usage_lock``.

    :returns: the counter value after the increment.
    """
    async with self._usage_lock:
        updated = self._usage_counter + 1
        self._usage_counter = updated
    return updated

536 

async def _decrement_usage(self) -> int:
    """
    Subtract one from the usage counter while holding ``_usage_lock``.

    :returns: the counter value after the decrement.
    """
    async with self._usage_lock:
        updated = self._usage_counter - 1
        self._usage_counter = updated
    return updated

545 

async def __aexit__(self, exc_type, exc_value, traceback):
    """
    Async context manager exit.

    Decrements the usage counter; when it reaches zero this was the last
    active context and the client is closed. Both steps are wrapped in
    ``asyncio.shield``.
    """
    remaining = await asyncio.shield(self._decrement_usage())
    if remaining != 0:
        # Other contexts are still using the client.
        return
    # This was the last active context, so disconnect the pool.
    await asyncio.shield(self.aclose())

555 

def __await__(self) -> Generator[Any, None, "RedisCluster"]:
    """Make ``await client`` equivalent to ``await client.initialize()``."""
    initializer = self.initialize()
    return initializer.__await__()

558 

_DEL_MESSAGE = "Unclosed RedisCluster client"

def __del__(
    self,
    _warn: Any = warnings.warn,
    _grl: Any = asyncio.get_running_loop,
) -> None:
    """
    Report a client that is garbage collected while still initialized.

    Emits a ``ResourceWarning`` and, when an event loop is running, passes
    a context dict to the loop's exception handler. ``_warn`` and ``_grl``
    are bound as defaults so the function does not look the names up at
    call time.
    """
    # Nothing to report if __init__ never got as far as setting the flag,
    # or if the client is not (or no longer) initialized.
    if getattr(self, "_initialize", True):
        return

    _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
    try:
        loop = _grl()
        loop.call_exception_handler(
            {"client": self, "message": self._DEL_MESSAGE}
        )
    except RuntimeError:
        # Raised when there is no running loop to report to.
        pass

573 

async def on_connect(self, connection: Connection) -> None:
    """
    Connection hook that switches a fresh connection to READONLY mode.

    Runs the connection's own ``on_connect`` first, then sends ``READONLY``.

    :raises ConnectionError: if the reply to ``READONLY`` is not ``OK``.
    """
    await connection.on_connect()

    # Sending READONLY command to server to configure connection as
    # readonly. Since each cluster node may change its server type due
    # to a failover, we should establish a READONLY connection
    # regardless of the server type. If this is a primary connection,
    # READONLY would not affect executing write commands.
    await connection.send_command("READONLY")
    reply = await connection.read_response()
    if str_if_bytes(reply) != "OK":
        raise ConnectionError("READONLY command failed")

585 

def get_nodes(self) -> List["ClusterNode"]:
    """Return a new list with every node in the nodes manager's cache."""
    cache = self.nodes_manager.nodes_cache
    return [*cache.values()]

589 

def get_primaries(self) -> List["ClusterNode"]:
    """Return the nodes the nodes manager reports for server type ``PRIMARY``."""
    manager = self.nodes_manager
    return manager.get_nodes_by_server_type(PRIMARY)

593 

def get_replicas(self) -> List["ClusterNode"]:
    """Return the nodes the nodes manager reports for server type ``REPLICA``."""
    manager = self.nodes_manager
    return manager.get_nodes_by_server_type(REPLICA)

597 

def get_random_node(self) -> "ClusterNode":
    """Return one node chosen at random from the nodes manager's cache."""
    candidates = list(self.nodes_manager.nodes_cache.values())
    return random.choice(candidates)

601 

def get_default_node(self) -> "ClusterNode":
    """
    Get the default node of the client.

    This is the ``default_node`` attribute of the nodes manager.
    """
    return self.nodes_manager.default_node

605 

def set_default_node(self, node: "ClusterNode") -> None:
    """
    Make ``node`` the client's default node.

    The node is looked up by its ``name`` first; only a node that
    :meth:`get_node` can find is accepted.

    :raises DataError: if ``node`` is falsy or is not found in the cluster.
    """
    if not (node and self.get_node(node_name=node.name)):
        raise DataError("The requested node does not exist in the cluster.")

    self.nodes_manager.default_node = node

616 

def get_node(
    self,
    host: Optional[str] = None,
    port: Optional[int] = None,
    node_name: Optional[str] = None,
) -> Optional["ClusterNode"]:
    """
    Look a node up by ``(host, port)`` or by ``node_name``.

    The three arguments are forwarded unchanged to the nodes manager's
    ``get_node``.
    """
    manager = self.nodes_manager
    return manager.get_node(host, port, node_name)

625 

def get_node_from_key(
    self, key: str, replica: bool = False
) -> Optional["ClusterNode"]:
    """
    Get the cluster node corresponding to the provided key.

    :param key:
        | Key whose slot is computed with :meth:`keyslot`
    :param replica:
        | When true, return the second node cached for the slot instead of
          the first one; ``None`` is returned if the slot has fewer than
          two cached nodes

    :raises SlotNotCoveredError: if the key's slot has no cached nodes.
    """
    slot = self.keyslot(key)
    slot_nodes = self.nodes_manager.slots_cache.get(slot)
    if not slot_nodes:
        raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')

    if not replica:
        return slot_nodes[0]

    # Index 0 is returned for non-replica lookups, so a replica needs index 1.
    if len(slot_nodes) < 2:
        return None
    return slot_nodes[1]

653 

def get_random_primary_or_all_nodes(self, command_name):
    """
    Return a single randomly chosen node for ``command_name``.

    The node is drawn from all nodes when ``read_from_replicas`` is set and
    the command is in ``READ_COMMANDS``; otherwise it is drawn from the
    primaries only.
    """
    any_node_allowed = self.read_from_replicas and command_name in READ_COMMANDS
    pick = self.get_random_node if any_node_allowed else self.get_random_primary_node
    return pick()

662 

def get_random_primary_node(self) -> "ClusterNode":
    """
    Return one node chosen at random from :meth:`get_primaries`.
    """
    primaries = self.get_primaries()
    return random.choice(primaries)

668 

async def get_nodes_from_slot(self, command: str, *args):
    """
    Return a one-element list holding the node that serves the command's slot.

    The slot comes from :meth:`_determine_slot`. Reading from replicas and
    the load balancing strategy are only passed on for commands listed in
    ``READ_COMMANDS``.
    """
    lookup = self.nodes_manager.get_node_from_slot
    slot = await self._determine_slot(command, *args)

    is_read_command = command in READ_COMMANDS
    use_replicas = self.read_from_replicas and is_read_command
    strategy = self.load_balancing_strategy if is_read_command else None

    return [lookup(slot, use_replicas, strategy)]

681 

def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
    """
    Return the nodes remembered in ``_aggregate_nodes``.

    :raises RedisClusterException: if no nodes have been remembered yet.
    """
    remembered = self._aggregate_nodes
    if remembered:
        return remembered

    raise RedisClusterException(
        "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
    )

692 

def keyslot(self, key: EncodableT) -> int:
    """
    Compute the hash slot of ``key``.

    The key is encoded with the client's encoder and the result is handed
    to ``key_slot``.

    See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
    """
    encoded_key = self.encoder.encode(key)
    return key_slot(encoded_key)

700 

def get_encoder(self) -> Encoder:
    """Get the encoder object of the client (the ``encoder`` attribute)."""
    return self.encoder

704 

def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
    """
    Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.

    The stored ``connection_kwargs`` dict itself is returned, not a copy.
    """
    return self.connection_kwargs

708 

def set_retry(self, retry: Retry) -> None:
    """Replace the client's ``retry`` object with ``retry``."""
    self.retry = retry

711 

def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
    """
    Set a custom response callback.

    Stores ``callback`` under ``command`` in ``response_callbacks``,
    replacing any callback already registered for that command.
    """
    self.response_callbacks[command] = callback

715 

async def _determine_nodes(
    self,
    command: str,
    *args: Any,
    request_policy: RequestPolicy,
    node_flag: Optional[str] = None,
) -> List["ClusterNode"]:
    """
    Work out which nodes ``command`` should be executed on.

    :param command: command name.
    :param args: remaining command arguments; used only for keyed routing.
    :param request_policy:
        policy to apply unless a node flag maps to a different one.
    :param node_flag:
        explicit node flag; when falsy, the flag predefined for the command
        in ``command_flags`` (if any) is used instead.
    :returns: the target nodes produced by the policy's callback.
    """
    # Fall back to the nodes group predefined for this command, if any.
    flag = node_flag or self.command_flags.get(command)

    # A known flag overrides the requested policy.
    request_policy = self._command_flags_mapping.get(flag, request_policy)
    resolve_targets = self._policies_callback_mapping[request_policy]

    # The callbacks take different arguments, and only the keyed one is
    # awaited.
    if request_policy == RequestPolicy.DEFAULT_KEYED:
        targets = await resolve_targets(command, *args)
    elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
        targets = resolve_targets(command)
    else:
        targets = resolve_targets()

    # Remember where FT.AGGREGATE ran; get_special_nodes() returns these.
    if command.lower() == "ft.aggregate":
        self._aggregate_nodes = targets

    return targets

745 

async def _determine_slot(self, command: str, *args: Any) -> int:
    """
    Compute the hash slot a command must be routed to.

    :param command: command name.
    :param args: the command's arguments (without the command name).
    :returns:
        the slot shared by all keys of the command, or a random slot for
        EVAL/EVALSHA/FCALL/FCALL_RO calls that carry no keys.
    :raises RedisClusterException:
        for malformed EVAL/EVALSHA arguments, for a command without keys,
        or when the keys do not all map to the same slot.
    """
    if self.command_flags.get(command) == SLOT_ID:
        # The command contains the slot ID
        return int(args[0])

    command_upper = command.upper()

    # EVAL and EVALSHA are common enough that it's wasteful to go to the
    # redis server to parse the keys. Besides, there is a bug in redis<7.0
    # where `self._get_command_keys()` fails anyway. So, we special case
    # EVAL/EVALSHA.
    # - issue: https://github.com/redis/redis/issues/9493
    # - fix: https://github.com/redis/redis/pull/9733
    if command_upper in ("EVAL", "EVALSHA"):
        # command syntax: EVAL "script body" num_keys ...
        if len(args) < 2:
            raise RedisClusterException(
                f"Invalid args in command: {command, *args}"
            )
        key_count = int(args[1])
        keys = args[2 : 2 + key_count]
        if not keys:
            # A script without keys can run on any node: pick a random slot.
            return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
    else:
        keys = await self.commands_parser.get_keys(command, *args)
        if not keys:
            # FCALL can call a function with 0 keys, that means the function
            # can be run on any node so we can just return a random slot
            if command_upper in ("FCALL", "FCALL_RO"):
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster. "
                "Missing key.\nYou can execute the command by specifying "
                f"target nodes.\nCommand: {args}"
            )

    # single key command
    if len(keys) == 1:
        (only_key,) = keys
        return self.keyslot(only_key)

    # multi-key command; every key has to hash to one and the same slot
    distinct_slots = {self.keyslot(key) for key in keys}
    if len(distinct_slots) != 1:
        raise RedisClusterException(
            f"{command} - all keys must map to the same key slot"
        )

    return distinct_slots.pop()

796 

def _is_node_flag(self, target_nodes: Any) -> bool:
    """Tell whether ``target_nodes`` is a string listed in ``node_flags``."""
    if not isinstance(target_nodes, str):
        return False
    return target_nodes in self.node_flags

799 

def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
    """
    Normalize an explicit ``target_nodes`` value to a list of nodes.

    :param target_nodes:
        a list (returned as is), a single :class:`~.ClusterNode`, or a
        dict whose values are the nodes.
    :raises TypeError: for any other type.
    """
    if isinstance(target_nodes, list):
        return target_nodes

    if isinstance(target_nodes, ClusterNode):
        # Supports passing a single ClusterNode as a variable
        return [target_nodes]

    if isinstance(target_nodes, dict):
        # Supports dictionaries of the format {node_name: node}.
        # It enables to execute commands with multi nodes as follows:
        # rc.cluster_save_config(rc.get_primaries())
        return list(target_nodes.values())

    raise TypeError(
        "target_nodes type can be one of the following: "
        "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
        "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
        f"The passed type is {type(target_nodes)}"
    )

819 

async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
    """
    Execute a raw command on the appropriate cluster node or target_nodes.

    It will retry the command as specified by the retries property of
    the :attr:`retry` & then raise an exception. When explicit nodes
    (anything other than a node flag) are passed via ``target_nodes``,
    the retry budget is set to zero and the command is attempted once.

    :param args:
        | Raw command args
    :param kwargs:

        - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
          or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
        - Rest of the kwargs are passed to the Redis connection

    :return:
        | Single target node: the reply, passed through the command's
          result callback (if one is registered) and then through the
          response-policy callback.
        | Several target nodes: the result callback's output over a
          ``{node name: reply}`` dict if one is registered, otherwise the
          response-policy callback applied to that dict.

    :raises RedisClusterException: if target_nodes is not provided & the command
        can't be mapped to a slot
    """
    command = args[0]
    target_nodes = []
    target_nodes_specified = False
    retry_attempts = self.retry.get_retries()

    # A node flag (string) is resolved later by _determine_nodes; only
    # concrete nodes count as "specified" and disable retries.
    passed_targets = kwargs.pop("target_nodes", None)
    if passed_targets and not self._is_node_flag(passed_targets):
        target_nodes = self._parse_target_nodes(passed_targets)
        target_nodes_specified = True
        retry_attempts = 0

    command_policies = await self._policy_resolver.resolve(args[0].lower())

    # The resolver returned nothing: derive the policies from the legacy
    # command flags, or from whether the command maps to a slot.
    if not command_policies and not target_nodes_specified:
        command_flag = self.command_flags.get(command)
        if not command_flag:
            # Fallback to default policy
            if not self.get_default_node():
                slot = None
            else:
                slot = await self._determine_slot(*args)
            # NOTE(review): `not slot` is also true for slot 0, which looks
            # like a valid slot number - confirm whether `slot is None` was
            # intended here.
            if not slot:
                command_policies = CommandPolicies()
            else:
                command_policies = CommandPolicies(
                    request_policy=RequestPolicy.DEFAULT_KEYED,
                    response_policy=ResponsePolicy.DEFAULT_KEYED,
                )
        else:
            if command_flag in self._command_flags_mapping:
                command_policies = CommandPolicies(
                    request_policy=self._command_flags_mapping[command_flag]
                )
            else:
                command_policies = CommandPolicies()
    elif not command_policies and target_nodes_specified:
        command_policies = CommandPolicies()

    # Add one for the first execution
    execute_attempts = 1 + retry_attempts
    for _ in range(execute_attempts):
        if self._initialize:
            await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
        try:
            if not target_nodes_specified:
                # Determine the nodes to execute the command on
                target_nodes = await self._determine_nodes(
                    *args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {args} command on"
                    )

            if len(target_nodes) == 1:
                # Return the processed result
                ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                if command in self.result_callbacks:
                    ret = self.result_callbacks[command](
                        command, {target_nodes[0].name: ret}, **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](ret)
            else:
                # Fan the command out to every target node concurrently and
                # key each reply by the name of the node that produced it.
                keys = [node.name for node in target_nodes]
                values = await asyncio.gather(
                    *(
                        asyncio.create_task(
                            self._execute_command(node, *args, **kwargs)
                        )
                        for node in target_nodes
                    )
                )
                if command in self.result_callbacks:
                    return self.result_callbacks[command](
                        command, dict(zip(keys, values)), **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](dict(zip(keys, values)))
        except Exception as e:
            # Only exact types listed in ERRORS_ALLOW_RETRY are retried
            # (type(e) is compared, so subclasses do not match).
            if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                # The nodes and slots cache should be reinitialized.
                # Try again with the new cluster setup.
                retry_attempts -= 1
                continue
            else:
                # raise the exception
                raise e

936 

async def _execute_command(
    self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
) -> Any:
    """
    Run one command against ``target_node``, following cluster redirections.

    The command is attempted at most ``RedisClusterRequestTTL`` times.
    ASK, MOVED and TRYAGAIN replies are handled inside the loop by
    re-targeting and trying again; the other handled errors are re-raised
    to the caller.

    :param target_node:
        | Node to send the command to on the first attempt
    :param args:
        | Raw command args
    :param kwargs:
        | Passed through to the node's ``execute_command``

    :raises ClusterError: if every attempt ended in a redirection
    """
    asking = moved = False
    redirect_addr = None
    ttl = self.RedisClusterRequestTTL

    while ttl > 0:
        ttl -= 1
        try:
            if asking:
                # Previous attempt got ASK: switch to the redirect target and
                # send ASKING before repeating the command there.
                target_node = self.get_node(node_name=redirect_addr)
                await target_node.execute_command("ASKING")
                asking = False
            elif moved:
                # MOVED occurred and the slots cache was updated,
                # refresh the target node
                slot = await self._determine_slot(*args)
                target_node = self.nodes_manager.get_node_from_slot(
                    slot,
                    self.read_from_replicas and args[0] in READ_COMMANDS,
                    self.load_balancing_strategy
                    if args[0] in READ_COMMANDS
                    else None,
                )
                moved = False

            return await target_node.execute_command(*args, **kwargs)
        except BusyLoadingError:
            # Re-raised untouched, ahead of the broader handlers below.
            raise
        except MaxConnectionsError:
            # MaxConnectionsError indicates client-side resource exhaustion
            # (too many connections in the pool), not a node failure.
            # Don't treat this as a node failure - just re-raise the error
            # without reinitializing the cluster.
            raise
        except (ConnectionError, TimeoutError):
            # Connection retries are being handled in the node's
            # Retry object.
            # Remove the failed node from the startup nodes before we try
            # to reinitialize the cluster
            self.nodes_manager.startup_nodes.pop(target_node.name, None)
            # Hard force of reinitialize of the node/slots setup
            # and try again with the new setup
            await self.aclose()
            raise
        except (ClusterDownError, SlotNotCoveredError):
            # ClusterDownError can occur during a failover and to get
            # self-healed, we will try to reinitialize the cluster layout
            # and retry executing the command

            # SlotNotCoveredError can occur when the cluster is not fully
            # initialized or can be temporary issue.
            # We will try to reinitialize the cluster topology
            # and retry executing the command

            await self.aclose()
            await asyncio.sleep(0.25)
            raise
        except MovedError as e:
            # First, we will try to patch the slots/nodes cache with the
            # redirected node output and try again. If MovedError exceeds
            # 'reinitialize_steps' number of times, we will force
            # reinitializing the tables, and then try again.
            # 'reinitialize_steps' counter will increase faster when
            # the same client object is shared between multiple threads. To
            # reduce the frequency you can set this variable in the
            # RedisCluster constructor.
            self.reinitialize_counter += 1
            if (
                self.reinitialize_steps
                and self.reinitialize_counter % self.reinitialize_steps == 0
            ):
                await self.aclose()
                # Reset the counter
                self.reinitialize_counter = 0
            else:
                # Stored here and consumed by the nodes manager the next time
                # a node is looked up by slot.
                self.nodes_manager._moved_exception = e
            moved = True
        except AskError as e:
            redirect_addr = get_node_name(host=e.host, port=e.port)
            asking = True
        except TryAgainError:
            # Back off briefly only once more than half of the attempts
            # have been used up.
            if ttl < self.RedisClusterRequestTTL / 2:
                await asyncio.sleep(0.05)

    raise ClusterError("TTL exhausted.")

1024 

def pipeline(
    self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
) -> "ClusterPipeline":
    """
    Build a new :class:`~.ClusterPipeline` bound to this client.

    :param transaction:
        | Forwarded unchanged to the :class:`~.ClusterPipeline` constructor
    :param shard_hint:
        | Not supported by the cluster pipeline; must be a falsy value

    :raises RedisClusterException: if shard_hint is a truthy value
    """
    if not shard_hint:
        return ClusterPipeline(self, transaction)
    raise RedisClusterException("shard_hint is deprecated in cluster mode")

1039 

def lock(
    self,
    name: KeyT,
    timeout: Optional[float] = None,
    sleep: float = 0.1,
    blocking: bool = True,
    blocking_timeout: Optional[float] = None,
    lock_class: Optional[Type[Lock]] = None,
    thread_local: bool = True,
    raise_on_release_error: bool = True,
) -> Lock:
    """
    Create a Lock on key ``name`` whose behavior mimics threading.Lock.

    ``timeout`` is the maximum life of the lock. When it is not given the
    lock stays held until release() is called.

    ``sleep`` is how long each loop iteration sleeps while the lock is in
    blocking mode and is currently held by another client.

    ``blocking`` selects whether ``acquire`` waits until the lock is
    obtained, or fails right away by returning False without acquiring
    it. Defaults to True. A ``blocking`` argument passed to ``acquire``
    overrides this value.

    ``blocking_timeout`` is the maximum number of seconds to spend trying
    to acquire the lock, given as a float or an integer. ``None`` means
    keep trying forever.

    ``lock_class`` forces a specific lock implementation. As of redis-py
    3.0 the only lock class implemented is ``Lock`` (a Lua-based lock), so
    this parameter is only needed with a custom lock class of your own.

    ``thread_local`` selects whether the lock token is kept in
    thread-local storage. It is by default, so that a thread only sees its
    own token and never one set by another thread. Consider this timeline:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                 thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
                 Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
                 key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
                 thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
                 token is *not* stored in thread local storage, then
                 thread-1 would see the token value as "xyz" and would be
                 able to successfully release the thread-2's lock.

    Some use cases need thread local storage disabled, for example code in
    which one thread acquires a lock and hands the lock instance to a
    worker thread that releases it later. With thread local storage left
    enabled, the worker thread would not see the token set by the
    acquiring thread. Such cases are assumed to be uncommon, hence the
    default.

    ``raise_on_release_error`` selects whether an exception is raised when
    the lock is no longer owned on exiting the context manager. True by
    default, meaning an exception is raised. If False, a warning is logged
    and the exception is suppressed.
    """
    factory = Lock if lock_class is None else lock_class
    options = dict(
        timeout=timeout,
        sleep=sleep,
        blocking=blocking,
        blocking_timeout=blocking_timeout,
        thread_local=thread_local,
        raise_on_release_error=raise_on_release_error,
    )
    return factory(self, name, **options)

1120 

async def transaction(
    self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
):
    """
    Convenience method for executing the callable `func` as a transaction
    while watching all keys specified in `watches`. The 'func' callable
    should expect a single argument which is a Pipeline object.

    The whole sequence (watch, ``func``, execute) is repeated for as long
    as it fails with a WatchError.

    :param func:
        | Awaitable callable invoked as ``await func(pipe)``
    :param watches:
        | Keys passed to ``pipe.watch`` before each attempt; skipped when empty
    :param kwargs:

        - shard_hint: forwarded to :meth:`pipeline`
        - value_from_callable: if truthy, return what ``func`` returned
          instead of the result of ``pipe.execute()``
        - watch_delay: seconds to wait between attempts after a WatchError;
          ``None`` or a non-positive value means retry immediately

    :return:
        | The result of ``pipe.execute()``, or of ``func`` when
          ``value_from_callable`` is truthy
    """
    shard_hint = kwargs.pop("shard_hint", None)
    value_from_callable = kwargs.pop("value_from_callable", False)
    watch_delay = kwargs.pop("watch_delay", None)
    async with self.pipeline(True, shard_hint) as pipe:
        while True:
            try:
                if watches:
                    await pipe.watch(*watches)
                func_value = await func(pipe)
                exec_value = await pipe.execute()
                return func_value if value_from_callable else exec_value
            except WatchError:
                if watch_delay is not None and watch_delay > 0:
                    # Must be the asyncio sleep: a blocking time.sleep()
                    # here would stall every other task on the event loop
                    # for the whole delay.
                    await asyncio.sleep(watch_delay)
                continue

1144 

1145 

class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are created lazily by :meth:`acquire_connection`, up to
    ``max_connections``; idle ones are kept in a free queue and reused.
    Two nodes compare (and hash) equal when their ``name`` is equal.
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        """
        :param host:
            | Node host; the literal ``"localhost"`` is resolved to an address
        :param port:
            | Node port
        :param server_type:
            | Role of the node, stored as given
        :param max_connections:
            | Upper bound on the connections this node will create
        :param connection_class:
            | Class instantiated for every new connection
        :param connection_kwargs:
            | Passed to ``connection_class``; ``response_callbacks`` is
              removed from it and kept on the node instead
        """
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # Every connection ever created by this node, busy or idle.
        self._connections: List[Connection] = []
        # Idle connections ready to be handed out again.
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None and makes instances
        # unhashable; hash on the same key that equality uses.
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        """Emit a ResourceWarning if a connection is still open at collection."""
        # _warn/_grl are captured as defaults - presumably so they stay
        # reachable while the interpreter is shutting down.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop to report to.
                    pass
                # One warning per node is enough.
                break

    async def disconnect(self) -> None:
        """
        Disconnect all connections concurrently.

        Every disconnect is attempted even if some fail; afterwards the
        first failure, if any, is raised.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Hand out an idle connection, creating a new one if none is free.

        :raises MaxConnectionsError: if no connection is free and
            ``max_connections`` connections already exist
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        """
        self._free.append(connection)

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and post-process it.

        :param connection:
            | Connection to read from
        :param command:
            | Command name, used to look up a response callback
        :param kwargs:

            - NEVER_DECODE: if present, the reply is read with decoding disabled
            - EMPTY_RESPONSE: if present, its value is returned in place of
              a ResponseError
            - keys: dropped before the callback is invoked
            - Rest of the kwargs are passed to the response callback

        :raises ResponseError: if the server replied with an error and no
            EMPTY_RESPONSE was supplied
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command on a pooled connection and return its parsed reply.

        The connection goes back to the free queue whether sending or
        reading succeeds or fails.
        """
        # Acquire connection
        connection = self.acquire_connection()

        try:
            # Execute command. Sending is inside the try block so that a
            # failed send does not leave the connection checked out forever.
            await connection.send_packed_command(connection.pack_command(*args), False)

            # Read response
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            # Release connection
            self.release(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in one write and read their replies in order.

        Each reply, or the exception raised while reading it, is stored on
        the command's ``result`` attribute.

        :return:
            | True if reading at least one reply raised, False otherwise
        """
        # Acquire connection
        connection = self.acquire_connection()

        try:
            # Execute command
            await connection.send_packed_command(
                connection.pack_commands(cmd.args for cmd in commands), False
            )

            # Read responses
            ret = False
            for cmd in commands:
                try:
                    cmd.result = await self.parse_response(
                        connection, cmd.args[0], **cmd.kwargs
                    )
                except Exception as e:
                    cmd.result = e
                    ret = True

            return ret
        finally:
            # Release connection, also when sending failed or the task was
            # cancelled mid-read, so the pool never loses a slot.
            self.release(connection)

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every currently idle connection with ``token``.

        Connections taken from the free queue are put back even if one of
        the AUTH round-trips raises.
        """
        tmp_queue = collections.deque()
        try:
            while self._free:
                conn = self._free.popleft()
                # Track the connection before awaiting anything so that a
                # failure below cannot drop it from the pool.
                tmp_queue.append(conn)
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
        finally:
            while tmp_queue:
                conn = tmp_queue.popleft()
                self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

1363 

1364 

class NodesManager:
    """
    Keep the client's view of the cluster: known nodes and slot ownership.

    ``nodes_cache`` maps node name to node, ``slots_cache`` maps a slot
    number to its nodes with the primary first, and ``startup_nodes`` holds
    the nodes queried by :meth:`initialize` to (re)discover the layout.
    """

    __slots__ = (
        "_dynamic_startup_nodes",
        "_moved_exception",
        "_event_dispatcher",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "read_load_balancer",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )

    def __init__(
        self,
        startup_nodes: List["ClusterNode"],
        require_full_coverage: bool,
        connection_kwargs: Dict[str, Any],
        dynamic_startup_nodes: bool = True,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        """
        :param startup_nodes:
            | Nodes to query for the cluster layout
        :param require_full_coverage:
            | If true, :meth:`initialize` fails unless every slot is covered
        :param connection_kwargs:
            | Passed to every :class:`~.ClusterNode` this manager creates
        :param dynamic_startup_nodes:
            | If true, the startup nodes are replaced by the discovered nodes
              after each successful :meth:`initialize`
        :param address_remap:
            | Optional callable mapping a reported ``(host, port)`` to the
              ``(host, port)`` to actually use
        :param event_dispatcher:
            | Dispatcher to use; a new one is created when omitted
        """
        self.startup_nodes = {node.name: node for node in startup_nodes}
        self.require_full_coverage = require_full_coverage
        self.connection_kwargs = connection_kwargs
        self.address_remap = address_remap

        self.default_node: "ClusterNode" = None
        self.nodes_cache: Dict[str, "ClusterNode"] = {}
        self.slots_cache: Dict[int, List["ClusterNode"]] = {}
        self.read_load_balancer = LoadBalancer()

        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
        # Last MOVED redirection not yet applied to slots_cache.
        self._moved_exception: MovedError = None
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """
        Look up a cached node by ``host`` and ``port``, or by ``node_name``.

        :return:
            | The cached node, or None if it is not in the nodes cache

        :raises DataError: if neither host and port nor node_name is given
        """
        if host and port:
            # the user passed host and port
            if host == "localhost":
                host = socket.gethostbyname(host)
            return self.nodes_cache.get(get_node_name(host=host, port=port))
        elif node_name:
            return self.nodes_cache.get(node_name)
        else:
            raise DataError(
                "get_node requires one of the following: 1. node name 2. host and port"
            )

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge ``new`` into ``old`` in place.

        Node objects that get replaced (and, with ``remove_old``, nodes
        missing from ``new``) are disconnected in background tasks.
        Must be called while an event loop is running.
        """
        # NOTE(review): the created tasks are not retained anywhere; the
        # asyncio docs advise keeping a reference to background tasks -
        # confirm whether these need one.
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    task = asyncio.create_task(old.pop(name).disconnect())  # noqa

        for name, node in new.items():
            if name in old:
                if old[name] is node:
                    continue
                task = asyncio.create_task(old[name].disconnect())  # noqa
            old[name] = node

    def update_moved_exception(self, exception):
        """Remember a MOVED error to be applied on the next slot lookup."""
        self._moved_exception = exception

    def _update_moved_slots(self) -> None:
        """Apply the pending MOVED redirection to the nodes and slots caches."""
        e = self._moved_exception
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        # A slot missing from the cache is treated as having no known nodes,
        # so the redirect target simply becomes its owner instead of the
        # lookup failing with a KeyError.
        slot_nodes = self.slots_cache.get(e.slot_id, [])
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op

        # Reset moved_exception
        self._moved_exception = None

    def get_node_from_slot(
        self,
        slot: int,
        read_from_replicas: bool = False,
        load_balancing_strategy=None,
    ) -> "ClusterNode":
        """
        Return the node that should serve ``slot``.

        Any pending MOVED redirection is applied first. Without a load
        balancing strategy the slot's first node (the primary) is returned;
        with one, the load balancer picks among the slot's nodes.

        :param slot:
            | Slot number to look up
        :param read_from_replicas:
            | If True and no strategy is given, round-robin is used
        :param load_balancing_strategy:
            | Strategy handed to the read load balancer

        :raises SlotNotCoveredError: if no node is known for the slot
        """
        if self._moved_exception:
            self._update_moved_slots()

        if read_from_replicas is True and load_balancing_strategy is None:
            load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN

        try:
            if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
                # get the server index using the strategy defined in load_balancing_strategy
                primary_name = self.slots_cache[slot][0].name
                node_idx = self.read_load_balancer.get_server_index(
                    primary_name, len(self.slots_cache[slot]), load_balancing_strategy
                )
                return self.slots_cache[slot][node_idx]
            return self.slots_cache[slot][0]
        except (IndexError, KeyError, TypeError):
            # slots_cache is a dict, so an uncovered slot surfaces as
            # KeyError; it must be translated as well or callers that
            # handle SlotNotCoveredError never see it.
            raise SlotNotCoveredError(
                f'Slot "{slot}" not covered by the cluster. '
                f'"require_full_coverage={self.require_full_coverage}"'
            )

    def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
        """Return the cached nodes whose ``server_type`` equals the given one."""
        return [
            node
            for node in self.nodes_cache.values()
            if node.server_type == server_type
        ]

    async def initialize(self) -> None:
        """
        Discover the cluster layout and rebuild the nodes and slots caches.

        Startup nodes are queried with ``CLUSTER SLOTS`` one after another
        until the accumulated answers cover every slot or no startup node
        is left. The caches are replaced only after all checks passed.

        :raises RedisClusterException: if no startup node could be queried,
            if the startup nodes disagree on slot owners too often, or if
            full coverage is required but not reached
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Convert to tuple to prevent RuntimeError if self.startup_nodes
        # is modified during iteration
        for startup_node in tuple(self.startup_nodes.values()):
            try:
                # Make sure cluster mode is enabled on this node
                try:
                    self._event_dispatcher.dispatch(
                        AfterAsyncClusterInstantiationEvent(
                            self.nodes_cache,
                            self.connection_kwargs.get("credential_provider", None),
                        )
                    )
                    cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
                except ResponseError:
                    raise RedisClusterException(
                        "Cluster mode is not enabled on this node"
                    )
                startup_nodes_reachable = True
            except Exception as e:
                # Try the next startup node.
                # The exception is saved and raised only if we have no more nodes.
                exception = e
                continue

            # CLUSTER SLOTS command results in the following output:
            # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
            # where each node contains the following list: [IP, port, node_id]
            # Therefore, cluster_slots[0][2][0] will be the IP address of the
            # primary node of the first slot section.
            # If there's only one server in the cluster, its ``host`` is ''
            # Fix it to the host in startup_nodes
            if (
                len(cluster_slots) == 1
                and not cluster_slots[0][2][0]
                and len(self.startup_nodes) == 1
            ):
                cluster_slots[0][2][0] = startup_node.host

            for slot in cluster_slots:
                for i in range(2, len(slot)):
                    slot[i] = [str_if_bytes(val) for val in slot[i]]
                primary_node = slot[2]
                host = primary_node[0]
                if host == "":
                    host = startup_node.host
                port = int(primary_node[1])
                host, port = self.remap_host_port(host, port)

                nodes_for_slot = []

                target_node = tmp_nodes_cache.get(get_node_name(host, port))
                if not target_node:
                    target_node = ClusterNode(
                        host, port, PRIMARY, **self.connection_kwargs
                    )
                # add this node to the nodes cache
                tmp_nodes_cache[target_node.name] = target_node
                nodes_for_slot.append(target_node)

                replica_nodes = slot[3:]
                for replica_node in replica_nodes:
                    host = replica_node[0]
                    # Normalise the port exactly like the primary's above so
                    # both roles produce node names of the same form.
                    port = int(replica_node[1])
                    host, port = self.remap_host_port(host, port)

                    target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_replica_node:
                        target_replica_node = ClusterNode(
                            host, port, REPLICA, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_replica_node.name] = target_replica_node
                    nodes_for_slot.append(target_replica_node)

                for i in range(int(slot[0]), int(slot[1]) + 1):
                    if i not in tmp_slots:
                        tmp_slots[i] = nodes_for_slot
                    else:
                        # Validate that 2 nodes want to use the same slot cache
                        # setup
                        tmp_slot = tmp_slots[i][0]
                        if tmp_slot.name != target_node.name:
                            disagreements.append(
                                f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                            )

                            if len(disagreements) > 5:
                                raise RedisClusterException(
                                    f"startup_nodes could not agree on a valid "
                                    f"slots cache: {', '.join(disagreements)}"
                                )

            # Validate if all slots are covered or if we should try next startup node
            fully_covered = all(
                i in tmp_slots for i in range(REDIS_CLUSTER_HASH_SLOTS)
            )
            if fully_covered:
                break

        if not startup_nodes_reachable:
            raise RedisClusterException(
                f"Redis Cluster cannot be connected. Please provide at least "
                f"one reachable node: {str(exception)}"
            ) from exception

        # Check if the slots are not fully covered
        if not fully_covered and self.require_full_coverage:
            # Despite the requirement that the slots be covered, there
            # isn't a full coverage
            raise RedisClusterException(
                f"All slots are not covered after query all startup_nodes. "
                f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                f"covered..."
            )

        # Set the tmp variables to the real variables
        self.slots_cache = tmp_slots
        self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

        if self._dynamic_startup_nodes:
            # Populate the startup nodes with all discovered nodes
            self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

        # Set the default node
        self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
        # If initialize was called after a MovedError, clear it
        self._moved_exception = None

    async def aclose(self, attr: str = "nodes_cache") -> None:
        """
        Clear the default node and disconnect every node of one mapping.

        :param attr:
            | Name of the attribute holding the ``{name: node}`` mapping
              whose nodes are disconnected
        """
        self.default_node = None
        await asyncio.gather(
            *(
                asyncio.create_task(node.disconnect())
                for node in getattr(self, attr).values()
            )
        )

    def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
        """
        Remap the host and port returned from the cluster to a different
        internal value.  Useful if the client is not connecting directly
        to the cluster.
        """
        if self.address_remap:
            return self.address_remap((host, port))
        return host, port

1676 

1677 

1678class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

1679 """ 

1680 Create a new ClusterPipeline object. 

1681 

1682 Usage:: 

1683 

1684 result = await ( 

1685 rc.pipeline() 

1686 .set("A", 1) 

1687 .get("A") 

1688 .hset("K", "F", "V") 

1689 .hgetall("K") 

1690 .mset_nonatomic({"A": 2, "B": 3}) 

1691 .get("A") 

1692 .get("B") 

1693 .delete("A", "B", "K") 

1694 .execute() 

1695 ) 

1696 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] 

1697 

1698 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which 

1699 are split across multiple nodes, you'll get multiple results for them in the array. 

1700 

1701 Retryable errors: 

1702 - :class:`~.ClusterDownError` 

1703 - :class:`~.ConnectionError` 

1704 - :class:`~.TimeoutError` 

1705 

1706 Redirection errors: 

1707 - :class:`~.TryAgainError` 

1708 - :class:`~.MovedError` 

1709 - :class:`~.AskError` 

1710 

1711 :param client: 

1712 | Existing :class:`~.RedisCluster` client 

1713 """ 

1714 

1715 __slots__ = ("cluster_client", "_transaction", "_execution_strategy") 

1716 

def __init__(
    self, client: RedisCluster, transaction: Optional[bool] = None
) -> None:
    """
    Bind the pipeline to ``client`` and choose how it will be executed.

    :param client:
        | Existing :class:`~.RedisCluster` client
    :param transaction:
        | If truthy a ``TransactionStrategy`` is used, otherwise a
          ``PipelineStrategy``
    """
    self.cluster_client = client
    self._transaction = transaction
    strategy_cls = TransactionStrategy if self._transaction else PipelineStrategy
    self._execution_strategy: ExecutionStrategy = strategy_cls(self)

1727 

async def initialize(self) -> "ClusterPipeline":
    """Initialize the execution strategy, then return this pipeline."""
    strategy = self._execution_strategy
    await strategy.initialize()
    return self

1731 

1732 async def __aenter__(self) -> "ClusterPipeline": 

1733 return await self.initialize() 

1734 

1735 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: 

1736 await self.reset() 

1737 

1738 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: 

1739 return self.initialize().__await__() 

1740 

1741 def __bool__(self) -> bool: 

1742 "Pipeline instances should always evaluate to True on Python 3+" 

1743 return True 

1744 

1745 def __len__(self) -> int: 

1746 return len(self._execution_strategy) 

1747 

1748 def execute_command( 

1749 self, *args: Union[KeyT, EncodableT], **kwargs: Any 

1750 ) -> "ClusterPipeline": 

1751 """ 

1752 Append a raw command to the pipeline. 

1753 

1754 :param args: 

1755 | Raw command args 

1756 :param kwargs: 

1757 

1758 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` 

1759 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] 

1760 - Rest of the kwargs are passed to the Redis connection 

1761 """ 

1762 return self._execution_strategy.execute_command(*args, **kwargs) 

1763 

1764 async def execute( 

1765 self, raise_on_error: bool = True, allow_redirections: bool = True 

1766 ) -> List[Any]: 

1767 """ 

1768 Execute the pipeline. 

1769 

1770 It will retry the commands as specified by retries specified in :attr:`retry` 

1771 & then raise an exception. 

1772 

1773 :param raise_on_error: 

1774 | Raise the first error if there are any errors 

1775 :param allow_redirections: 

1776 | Whether to retry each failed command individually in case of redirection 

1777 errors 

1778 

1779 :raises RedisClusterException: if target_nodes is not provided & the command 

1780 can't be mapped to a slot 

1781 """ 

1782 try: 

1783 return await self._execution_strategy.execute( 

1784 raise_on_error, allow_redirections 

1785 ) 

1786 finally: 

1787 await self.reset() 

1788 

1789 def _split_command_across_slots( 

1790 self, command: str, *keys: KeyT 

1791 ) -> "ClusterPipeline": 

1792 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values(): 

1793 self.execute_command(command, *slot_keys) 

1794 

1795 return self 

1796 

1797 async def reset(self): 

1798 """ 

1799 Reset back to empty pipeline. 

1800 """ 

1801 await self._execution_strategy.reset() 

1802 

1803 def multi(self): 

1804 """ 

1805 Start a transactional block of the pipeline after WATCH commands 

1806 are issued. End the transactional block with `execute`. 

1807 """ 

1808 self._execution_strategy.multi() 

1809 

1810 async def discard(self): 

1811 """ """ 

1812 await self._execution_strategy.discard() 

1813 

1814 async def watch(self, *names): 

1815 """Watches the values at keys ``names``""" 

1816 await self._execution_strategy.watch(*names) 

1817 

1818 async def unwatch(self): 

1819 """Unwatches all previously specified keys""" 

1820 await self._execution_strategy.unwatch() 

1821 

1822 async def unlink(self, *names): 

1823 await self._execution_strategy.unlink(*names) 

1824 

1825 def mset_nonatomic( 

1826 self, mapping: Mapping[AnyKeyT, EncodableT] 

1827 ) -> "ClusterPipeline": 

1828 return self._execution_strategy.mset_nonatomic(mapping) 

1829 

1830 

# Replace every blocked command on ClusterPipeline with the callable produced
# by block_pipeline_command(); mset_nonatomic is skipped so the method defined
# on the class stays in place.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))

1837 

1838 

class PipelineCommand:
    """One queued pipeline command together with its eventual result."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index of the command inside the pipeline queue.
        self.position = position
        # Raw command arguments and the keyword options supplied with them.
        self.args = args
        self.kwargs = kwargs
        # Filled in later: the command's reply, or the exception it produced.
        self.result: Union[Any, Exception] = None
        # Filled in later with the policies chosen for this command.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)

1849 

1850 

class ExecutionStrategy(ABC):
    """Interface that every ClusterPipeline execution strategy implements."""

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the execution strategy.

        See ClusterPipeline.initialize()
        """

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        See ClusterPipeline.execute_command()
        """

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        See ClusterPipeline.execute()
        """

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Executes multiple MSET commands according to the provided slot/pairs mapping.

        See ClusterPipeline.mset_nonatomic()
        """

    @abstractmethod
    async def reset(self):
        """
        Resets current execution strategy.

        See: ClusterPipeline.reset()
        """

    @abstractmethod
    def multi(self):
        """
        Starts transactional context.

        See: ClusterPipeline.multi()
        """

    @abstractmethod
    async def watch(self, *names):
        """
        Watch given keys.

        See: ClusterPipeline.watch()
        """

    @abstractmethod
    async def unwatch(self):
        """
        Unwatches all previously specified keys

        See: ClusterPipeline.unwatch()
        """

    @abstractmethod
    async def discard(self):
        """
        Discard handler of the strategy.

        See: ClusterPipeline.discard()
        """

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink the key(s) specified by ``names``.

        See: ClusterPipeline.unlink()
        """

    @abstractmethod
    def __len__(self) -> int:
        """Length of the strategy, as reported through ``len(pipeline)``."""

1949 

1950 

class AbstractStrategy(ExecutionStrategy):
    """
    Shared base for the concrete execution strategies.

    Holds the owning pipeline and the list of queued commands, and implements
    the parts that do not differ between strategies: initialization, queueing
    a command, exception annotation and ``__len__``. Everything else stays
    abstract.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the cluster client when its ``_initialize`` flag is set,
        empty the command queue and return the owning pipeline.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Queue a command; its position is the queue length before appending.

        :returns: the owning pipeline, so calls can be chained
        """
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled

        The first element of ``exception.args`` is replaced by a message that
        names the command number and the (truncated) command text; any further
        elements of ``exception.args`` are kept.

        :param exception: exception instance to annotate in place
        :param number: command number to report in the message
        :param command: command arguments; a bare ``str``/``bytes`` is treated
            as a single-argument command
        """
        # A bare string would otherwise be iterated character by character
        # and rendered as e.g. "M U L T I".
        if isinstance(command, (str, bytes)):
            command = (command,)
        cmd = " ".join(map(safe_str, command))

        # Exceptions raised without arguments have an empty ``args`` tuple;
        # indexing it would mask the real error with an IndexError.
        original_message = exception.args[0] if exception.args else ""
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {original_message}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        """Number of commands currently queued."""
        return len(self._command_queue)

2019 

2020 

class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional execution strategy.

    Queued commands are grouped by target node and each node's group is sent
    through that node's ``execute_pipeline``. The transactional helpers
    (``multi``, ``watch``, ``unwatch``, ``discard``) raise
    :class:`RedisClusterException`.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one ``MSET`` per hash slot touched by ``mapping``.

        :param mapping: key/value pairs to set
        :returns: the owning pipeline
        """
        encoder = self._pipe.cluster_client.encoder

        # slot -> flat [key, value, key, value, ...] list for that slot
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands and return their results in queue order.

        Errors listed in ``RedisCluster.ERRORS_ALLOW_RETRY`` trigger another
        attempt (after closing the client and a short sleep) until the retry
        budget from ``cluster_client.retry.get_retries()`` is used up; then the
        error propagates. The queue is reset on every exit path.

        :param raise_on_error: raise the first command error instead of
            returning it inside the result list
        :param allow_redirections: re-send commands that failed with a
            redirection error individually through the cluster client
        :returns: ``[]`` when nothing is queued, else one entry per command
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # Retry budget exhausted: propagate the error.
                        raise
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        Perform a single execution attempt for ``stack``.

        :param client: cluster client used for policy/node resolution and for
            individually re-sending redirected commands
        :param stack: the queued commands
        :param raise_on_error: raise the first exception found in the results
        :param allow_redirections: re-send commands whose result is a
            ``TryAgainError``, ``MovedError`` or ``AskError``
        :returns: ``cmd.result`` for every command in ``stack``
        :raises RedisClusterException: when a command resolves to no target
            node or to more than one
        """
        # Commands whose result is falsy (including "not run yet") or an
        # exception are (re-)sent on this attempt.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # node name -> (node, [commands routed to that node])
        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Explicit nodes were given by the caller.
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        # Compare against None explicitly: 0 is a valid hash
                        # slot and must be treated as a keyed command.
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # One task per node; each entry of ``errors`` is that node's
        # execute_pipeline() return value.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        """Not supported by this strategy; always raises."""
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        """Not supported by this strategy; always raises."""
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        """Not supported by this strategy; always raises."""
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        """Not supported by this strategy; always raises."""
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """
        Queue ``UNLINK`` for exactly one key.

        :raises RedisClusterException: when ``names`` does not hold exactly
            one key
        """
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])

2224 

2225 

class TransactionStrategy(AbstractStrategy):
    """
    Transactional (MULTI/EXEC) execution strategy.

    All commands of one transaction must map to a single hash slot. The
    strategy acquires one connection from the node owning that slot and keeps
    using it for WATCH/UNWATCH and for the final MULTI ... EXEC batch.
    """

    # Commands that are allowed without a resolvable slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands that are sent immediately instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which the connection is no longer watching keys.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        self._pipeline_slots: Set[int] = set()
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Work on a copy so the client's own retry object is left untouched.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not been
        altered as long as the watch commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to the
        node that owns that slot and reuse it.

        :raises RedisClusterException: when no slot has been recorded yet
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            list(self._pipeline_slots)[0], False
        )
        self._transaction_node = node

        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """
        Synchronous entry point that drives :meth:`_execute_command`.

        The coroutine is run with ``asyncio.run`` on a helper thread which is
        joined before returning; an exception raised there is re-raised here.

        :returns: whatever :meth:`_execute_command` returned
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Record the command's slot and either queue it or prepare it for
        immediate execution.

        While watching (or for WATCH/UNWATCH themselves) and outside an
        explicit MULTI block, the un-awaited coroutine of
        :meth:`_immediate_execute_command` is returned; otherwise the command
        is queued and the pipeline is returned.

        :raises CrossSlotTransactionError: when an immediate command targets a
            slot other than the one(s) already recorded
        :raises RedisClusterException: when no slot can be identified for an
            immediate command that requires one
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # Intentionally not awaited here: the coroutine object is handed
            # back to the caller.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        """Mark the strategy as watching; WATCH after MULTI is rejected."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        """Send a command right away, going through the strategy's retry object."""
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        """Resolve the transaction node/connection and send the command on it."""
        redis_node, connection = self._get_client_and_connection_for_transaction()
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        """
        Failure callback passed to ``self._retry.call_with_retry``.

        :raises WatchError: on a slot redirect while watching keys during
            transaction execution
        """
        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            # Drop the cached connection so the next attempt acquires a new one.
            if self._transaction_connection:
                self._transaction_connection = None

            self._pipe.cluster_client.reinitialize_counter += 1
            if (
                self._pipe.cluster_client.reinitialize_steps
                and self._pipe.cluster_client.reinitialize_counter
                % self._pipe.cluster_client.reinitialize_steps
                == 0
            ):
                await self._pipe.cluster_client.nodes_manager.initialize()
                # Reset the counter that was incremented above. It lives on
                # the cluster client, not on this strategy object.
                self._pipe.cluster_client.reinitialize_counter = 0
            else:
                if isinstance(error, AskError):
                    self._pipe.cluster_client.nodes_manager.update_moved_exception(
                        error
                    )

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Not available in a transaction; always raises NotImplementedError."""
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the queued commands as one transaction.

        Returns ``[]`` without contacting the server when nothing is queued
        and the strategy is not watching a slot. ``allow_redirections`` is
        accepted for interface compatibility and is not used here.
        """
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Run :meth:`_execute_transaction` through the strategy's retry object."""
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send ``MULTI``, the queued commands and ``EXEC`` in one packed write,
        then read back and post-process the replies.

        :raises CrossSlotTransactionError: when more than one slot is involved
        :raises WatchError: when EXEC returns ``None``
        :raises InvalidPipelineStack: when the EXEC reply length does not match
            the number of queued commands
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            # Pass the command as a one-element tuple so the annotation reads
            # "MULTI" rather than the characters joined by spaces.
            self._annotate_exception(e, 0, ("MULTI",))
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, ("MULTI",))
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # ``errors`` may also hold (index, value) tuples for EMPTY_RESPONSE
            # commands; only a real exception can be raised.
            for entry in errors:
                if isinstance(entry, Exception):
                    raise entry
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): this unpacking assumes every entry is an
        # (index, value) tuple, but bare exceptions are appended above as
        # well - confirm whether they can reach this point.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)
        return data

    async def reset(self):
        """
        Empty the queue, release the transaction connection (sending UNWATCH
        first when still watching) and clear all transaction state.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection:
                    await self._transaction_connection.disconnect()

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        """
        Start an explicit transaction block.

        :raises RedisError: on a nested call, or when commands were already
            queued
        """
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        """
        Issue ``WATCH`` for ``names``.

        :raises RedisError: when called after ``multi()``
        """
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Issue ``UNWATCH`` when currently watching; otherwise return ``True``."""
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        """Discard the transaction by resetting the strategy."""
        await self.reset()

    async def unlink(self, *names):
        """Pass ``UNLINK`` for ``names`` to :meth:`execute_command`."""
        return self.execute_command("UNLINK", *names)