Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/cluster.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1048 statements  

1import asyncio 

2import collections 

3import random 

4import socket 

5import threading 

6import time 

7import warnings 

8from abc import ABC, abstractmethod 

9from copy import copy 

10from itertools import chain 

11from typing import ( 

12 Any, 

13 Callable, 

14 Coroutine, 

15 Deque, 

16 Dict, 

17 Generator, 

18 List, 

19 Mapping, 

20 Optional, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypeVar, 

25 Union, 

26) 

27 

28from redis._parsers import AsyncCommandsParser, Encoder 

29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy 

30from redis._parsers.helpers import ( 

31 _RedisCallbacks, 

32 _RedisCallbacksRESP2, 

33 _RedisCallbacksRESP3, 

34) 

35from redis.asyncio.client import ResponseCallbackT 

36from redis.asyncio.connection import Connection, SSLConnection, parse_url 

37from redis.asyncio.lock import Lock 

38from redis.asyncio.retry import Retry 

39from redis.auth.token import TokenInterface 

40from redis.backoff import ExponentialWithJitterBackoff, NoBackoff 

41from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis 

42from redis.cluster import ( 

43 PIPELINE_BLOCKED_COMMANDS, 

44 PRIMARY, 

45 REPLICA, 

46 SLOT_ID, 

47 AbstractRedisCluster, 

48 LoadBalancer, 

49 LoadBalancingStrategy, 

50 block_pipeline_command, 

51 get_node_name, 

52 parse_cluster_slots, 

53) 

54from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands 

55from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver 

56from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot 

57from redis.credentials import CredentialProvider 

58from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher 

59from redis.exceptions import ( 

60 AskError, 

61 BusyLoadingError, 

62 ClusterDownError, 

63 ClusterError, 

64 ConnectionError, 

65 CrossSlotTransactionError, 

66 DataError, 

67 ExecAbortError, 

68 InvalidPipelineStack, 

69 MaxConnectionsError, 

70 MovedError, 

71 RedisClusterException, 

72 RedisError, 

73 ResponseError, 

74 SlotNotCoveredError, 

75 TimeoutError, 

76 TryAgainError, 

77 WatchError, 

78) 

79from redis.typing import AnyKeyT, EncodableT, KeyT 

80from redis.utils import ( 

81 SSL_AVAILABLE, 

82 deprecated_args, 

83 deprecated_function, 

84 get_lib_version, 

85 safe_str, 

86 str_if_bytes, 

87 truncate_text, 

88) 

89 

# The ssl enums are only importable when SSL support is available; otherwise
# bind None placeholders so the names used in annotations further down in
# this module still exist.
if not SSL_AVAILABLE:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None
else:
    from ssl import TLSVersion, VerifyFlags, VerifyMode

# Constrained type variable: a string, a single ClusterNode, a list of
# ClusterNode, or a dict whose values are ClusterNode.
TargetNodesT = TypeVar(
    "TargetNodesT",
    str,
    "ClusterNode",
    List["ClusterNode"],
    Dict[Any, "ClusterNode"],
)

100 

101 

102class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

103 """ 

104 Create a new RedisCluster client. 

105 

106 Pass one of parameters: 

107 

108 - `host` & `port` 

109 - `startup_nodes` 

110 

111 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. 

112 | Use ``await`` :meth:`close` to disconnect connections & close client. 

113 

114 Many commands support the target_nodes kwarg. It can be one of the 

115 :attr:`NODE_FLAGS`: 

116 

117 - :attr:`PRIMARIES` 

118 - :attr:`REPLICAS` 

119 - :attr:`ALL_NODES` 

120 - :attr:`RANDOM` 

121 - :attr:`DEFAULT_NODE` 

122 

123 Note: This client is not thread/process/fork safe. 

124 

125 :param host: 

126 | Can be used to point to a startup node 

127 :param port: 

128 | Port used if **host** is provided 

129 :param startup_nodes: 

130 | :class:`~.ClusterNode` to be used as a startup node

131 :param require_full_coverage: 

132 | When set to ``False``: the client will not require a full coverage of 

133 the slots. However, if not all slots are covered, and at least one node 

134 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw 

135 a :class:`~.ClusterDownError` for some key-based commands. 

136 | When set to ``True``: all slots must be covered to construct the cluster 

137 client. If not all slots are covered, :class:`~.RedisClusterException` will be 

138 thrown. 

139 | See: 

140 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters 

141 :param read_from_replicas: 

142 | @deprecated - please use load_balancing_strategy instead 

143 | Enable read from replicas in READONLY mode. 

144 When set to true, read commands will be assigned between the primary and 

145 its replications in a Round-Robin manner. 

146 The data read from replicas is eventually consistent with the data in primary nodes. 

147 :param load_balancing_strategy: 

148 | Enable read from replicas in READONLY mode and defines the load balancing 

149 strategy that will be used for cluster node selection. 

150 The data read from replicas is eventually consistent with the data in primary nodes. 

151 :param dynamic_startup_nodes: 

152 | Set the RedisCluster's startup nodes to all the discovered nodes. 

153 If true (default value), the cluster's discovered nodes will be used to 

154 determine the cluster nodes-slots mapping in the next topology refresh. 

155 It will remove the initial passed startup nodes if their endpoints aren't 

156 listed in the CLUSTER SLOTS output. 

157 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists 

158 specific IP addresses, it is best to set it to false. 

159 :param reinitialize_steps: 

160 | Specifies the number of MOVED errors that need to occur before reinitializing 

161 the whole cluster topology. If a MOVED error occurs and the cluster does not 

162 need to be reinitialized on this current error handling, only the MOVED slot 

163 will be patched with the redirected node. 

164 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. 

165 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 

166 0. 

167 :param cluster_error_retry_attempts: 

168 | @deprecated - Please configure the 'retry' object instead 

169 In case 'retry' object is set - this argument is ignored! 

170 

171 Number of times to retry before raising an error when :class:`~.TimeoutError`, 

172 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` 

173 or :class:`~.ClusterDownError` are encountered 

174 :param retry: 

175 | A retry object that defines the retry strategy and the number of 

176 retries for the cluster client. 

177 In current implementation for the cluster client (starting from redis-py version 6.0.0)

178 the retry object is not yet fully utilized, instead it is used just to determine 

179 the number of retries for the cluster client. 

180 In the future releases the retry object will be used to handle the cluster client retries! 

181 :param max_connections: 

182 | Maximum number of connections per node. If there are no free connections & the 

183 maximum number of connections are already created, a 

184 :class:`~.MaxConnectionsError` is raised. 

185 :param address_remap: 

186 | An optional callable which, when provided with an internal network 

187 address of a node, e.g. a `(host, port)` tuple, will return the address 

188 where the node is reachable. This can be used to map the addresses at 

189 which the nodes _think_ they are, to addresses at which a client may 

190 reach them, such as when they sit behind a proxy. 

191 

192 | Rest of the arguments will be passed to the 

193 :class:`~redis.asyncio.connection.Connection` instances when created 

194 

195 :raises RedisClusterException: 

196 if any arguments are invalid or unknown. Eg: 

197 

198 - `db` != 0 or None 

199 - `path` argument for unix socket connection 

200 - none of the `host`/`port` & `startup_nodes` were provided 

201 

202 """ 

203 

@classmethod
def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
    """
    Build a cluster client from a connection URL.

    For example::

        redis://[[username]:[password]]@localhost:6379/0
        rediss://[[username]:[password]]@localhost:6379/0

    Two URL schemes are listed here:

    - `redis://` creates a TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/redis>
    - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/rediss>

    The URL is handed to ``parse_url``; the options it returns are merged
    over ``kwargs``, so a value taken from the URL wins over a keyword
    argument of the same name. If the parsed options select
    ``SSLConnection`` as the connection class, that entry is removed and
    ``ssl=True`` is passed to the constructor instead.

    :param url:
        | Connection URL describing a startup node
    :param kwargs:
        | Extra constructor arguments, overridden by values from the URL
    """
    url_options = parse_url(url)
    kwargs.update(url_options)

    # The constructor takes an ``ssl`` flag rather than a connection class,
    # so translate the parsed class into that flag and always drop the key.
    connection_class = kwargs.pop("connection_class", None)
    if connection_class is SSLConnection:
        kwargs["ssl"] = True

    return cls(**kwargs)

236 

# Attribute names declared as slots on this class. __init__ also assigns
# attributes that are not listed here (e.g. ``_event_dispatcher``), so
# instances presumably still carry a ``__dict__`` via a base class -- confirm.
__slots__ = (
    # lifecycle
    "_initialize",
    "_lock",
    # configuration and helpers
    "retry",
    "command_flags",
    "commands_parser",
    "connection_kwargs",
    "encoder",
    "node_flags",
    "nodes_manager",
    "read_from_replicas",
    # MOVED-error bookkeeping
    "reinitialize_counter",
    "reinitialize_steps",
    # response post-processing
    "response_callbacks",
    "result_callbacks",
)

253 

@deprecated_args(
    args_to_warn=["read_from_replicas"],
    reason="Please configure the 'load_balancing_strategy' instead",
    version="5.3.0",
)
@deprecated_args(
    args_to_warn=[
        "cluster_error_retry_attempts",
    ],
    reason="Please configure the 'retry' object instead",
    version="6.0.0",
)
def __init__(
    self,
    host: Optional[str] = None,
    port: Union[str, int] = 6379,
    # Cluster related kwargs
    startup_nodes: Optional[List["ClusterNode"]] = None,
    require_full_coverage: bool = True,
    read_from_replicas: bool = False,
    load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
    dynamic_startup_nodes: bool = True,
    reinitialize_steps: int = 5,
    cluster_error_retry_attempts: int = 3,
    max_connections: int = 2**31,
    retry: Optional["Retry"] = None,
    retry_on_error: Optional[List[Type[Exception]]] = None,
    # Client related kwargs
    db: Union[str, int] = 0,
    path: Optional[str] = None,
    credential_provider: Optional[CredentialProvider] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    client_name: Optional[str] = None,
    lib_name: Optional[str] = "redis-py",
    lib_version: Optional[str] = get_lib_version(),
    # Encoding related kwargs
    encoding: str = "utf-8",
    encoding_errors: str = "strict",
    decode_responses: bool = False,
    # Connection related kwargs
    health_check_interval: float = 0,
    socket_connect_timeout: Optional[float] = None,
    socket_keepalive: bool = False,
    socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
    socket_timeout: Optional[float] = None,
    # SSL related kwargs
    ssl: bool = False,
    ssl_ca_certs: Optional[str] = None,
    ssl_ca_data: Optional[str] = None,
    ssl_cert_reqs: Union[str, VerifyMode] = "required",
    ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_certfile: Optional[str] = None,
    ssl_check_hostname: bool = True,
    ssl_keyfile: Optional[str] = None,
    ssl_min_version: Optional[TLSVersion] = None,
    ssl_ciphers: Optional[str] = None,
    protocol: Optional[int] = 2,
    address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
    event_dispatcher: Optional[EventDispatcher] = None,
    policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
) -> None:
    """
    Validate the arguments and assemble the client's state.

    The parameters are described in the class docstring. This method builds
    the per-connection keyword arguments, the retry object, the startup
    node list, the nodes manager and the policy/callback tables, and leaves
    the client flagged as "needs initialization".

    :raises RedisClusterException:
        if ``db`` is truthy, ``path`` is given, or neither ``host``/``port``
        nor ``startup_nodes`` identify a node.
    """
    # --- argument validation -------------------------------------------
    if db:
        raise RedisClusterException("Argument 'db' must be 0 or None in cluster mode")

    if path:
        raise RedisClusterException(
            "Unix domain socket is not supported in cluster mode"
        )

    has_address = bool(host and port)
    if not has_address and not startup_nodes:
        raise RedisClusterException(
            "RedisCluster requires at least one node to discover the cluster.\n"
            "Please provide one of the following or use RedisCluster.from_url:\n"
            ' - host and port: RedisCluster(host="localhost", port=6379)\n'
            " - startup_nodes: RedisCluster(startup_nodes=["
            'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
        )

    # --- keyword arguments shared by every node connection -------------
    conn_kwargs: Dict[str, Any] = {
        "max_connections": max_connections,
        "connection_class": Connection,
        # Client related kwargs
        "credential_provider": credential_provider,
        "username": username,
        "password": password,
        "client_name": client_name,
        "lib_name": lib_name,
        "lib_version": lib_version,
        # Encoding related kwargs
        "encoding": encoding,
        "encoding_errors": encoding_errors,
        "decode_responses": decode_responses,
        # Connection related kwargs
        "health_check_interval": health_check_interval,
        "socket_connect_timeout": socket_connect_timeout,
        "socket_keepalive": socket_keepalive,
        "socket_keepalive_options": socket_keepalive_options,
        "socket_timeout": socket_timeout,
        "protocol": protocol,
    }

    if ssl:
        # SSL related kwargs are only forwarded when SSL is requested.
        conn_kwargs["connection_class"] = SSLConnection
        conn_kwargs["ssl_ca_certs"] = ssl_ca_certs
        conn_kwargs["ssl_ca_data"] = ssl_ca_data
        conn_kwargs["ssl_cert_reqs"] = ssl_cert_reqs
        conn_kwargs["ssl_include_verify_flags"] = ssl_include_verify_flags
        conn_kwargs["ssl_exclude_verify_flags"] = ssl_exclude_verify_flags
        conn_kwargs["ssl_certfile"] = ssl_certfile
        conn_kwargs["ssl_check_hostname"] = ssl_check_hostname
        conn_kwargs["ssl_keyfile"] = ssl_keyfile
        conn_kwargs["ssl_min_version"] = ssl_min_version
        conn_kwargs["ssl_ciphers"] = ssl_ciphers

    if read_from_replicas or load_balancing_strategy:
        # Route connection setup through on_connect so READONLY is sent.
        conn_kwargs["redis_connect_func"] = self.on_connect

    # --- retry policy ---------------------------------------------------
    # A caller-supplied retry object wins; the deprecated attempt count is
    # only used to build the fallback.
    self.retry = retry or Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10),
        retries=cluster_error_retry_attempts,
    )
    if retry_on_error:
        self.retry.update_supported_errors(retry_on_error)

    # --- response callbacks (protocol dependent) -----------------------
    callbacks = _RedisCallbacks.copy()
    conn_kwargs["response_callbacks"] = callbacks
    callbacks.update(
        _RedisCallbacksRESP3 if protocol in ("3", 3) else _RedisCallbacksRESP2
    )
    self.connection_kwargs = conn_kwargs

    # --- startup nodes --------------------------------------------------
    # Passed nodes are rebuilt from their host/port with this client's
    # connection kwargs.
    seed_nodes = [
        ClusterNode(node.host, node.port, **conn_kwargs)
        for node in (startup_nodes or [])
    ]
    if has_address:
        seed_nodes.append(ClusterNode(host, port, **conn_kwargs))

    self._event_dispatcher = (
        EventDispatcher() if event_dispatcher is None else event_dispatcher
    )

    self.startup_nodes = seed_nodes
    self.nodes_manager = NodesManager(
        seed_nodes,
        require_full_coverage,
        conn_kwargs,
        dynamic_startup_nodes=dynamic_startup_nodes,
        address_remap=address_remap,
        event_dispatcher=self._event_dispatcher,
    )
    self.encoder = Encoder(encoding, encoding_errors, decode_responses)
    self.read_from_replicas = read_from_replicas
    self.load_balancing_strategy = load_balancing_strategy
    self.reinitialize_steps = reinitialize_steps
    self.reinitialize_counter = 0

    # --- policy tables --------------------------------------------------
    client_cls = self.__class__

    # For backward compatibility, mapping from existing policies to new one
    self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
        client_cls.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
        client_cls.PRIMARIES: RequestPolicy.ALL_SHARDS,
        client_cls.ALL_NODES: RequestPolicy.ALL_NODES,
        client_cls.REPLICAS: RequestPolicy.ALL_REPLICAS,
        client_cls.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
        SLOT_ID: RequestPolicy.DEFAULT_KEYED,
    }

    # Each request policy maps to the callable that yields its target nodes;
    # the two response policies listed pass the result through unchanged.
    self._policies_callback_mapping: dict[
        Union[RequestPolicy, ResponsePolicy], Callable
    ] = {
        RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
            self.get_random_primary_or_all_nodes(command_name)
        ],
        RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
        RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
        RequestPolicy.ALL_SHARDS: self.get_primaries,
        RequestPolicy.ALL_NODES: self.get_nodes,
        RequestPolicy.ALL_REPLICAS: self.get_replicas,
        RequestPolicy.SPECIAL: self.get_special_nodes,
        ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
        ResponsePolicy.DEFAULT_KEYED: lambda res: res,
    }

    self._policy_resolver = policy_resolver
    self.commands_parser = AsyncCommandsParser()
    self._aggregate_nodes = None
    self.node_flags = client_cls.NODE_FLAGS.copy()
    self.command_flags = client_cls.COMMAND_FLAGS.copy()
    self.response_callbacks = callbacks
    self.result_callbacks = client_cls.RESULT_CALLBACKS.copy()
    # CLUSTER SLOTS results arrive keyed per node; parse the first value only.
    self.result_callbacks["CLUSTER SLOTS"] = (
        lambda cmd, res, **options: parse_cluster_slots(
            list(res.values())[0], **options
        )
    )

    # True means "initialize() still has to run".
    self._initialize = True
    # Created lazily by initialize()/aclose().
    self._lock: Optional[asyncio.Lock] = None

    # When used as an async context manager, a usage counter is kept so the
    # connection pool is only closed once no context is using the client.
    self._usage_counter = 0
    self._usage_lock = asyncio.Lock()

477 

async def initialize(self) -> "RedisCluster":
    """
    Initialize the nodes manager and the commands parser if still pending.

    Does nothing when the client is already initialized. If any step fails,
    the nodes manager is closed (including its startup nodes) and the
    exception is re-raised.

    :returns: this client, so the call can be chained or awaited inline.
    """
    if not self._initialize:
        return self

    if not self._lock:
        self._lock = asyncio.Lock()

    async with self._lock:
        # Check again under the lock: another task may have completed the
        # initialization while this one was waiting.
        if not self._initialize:
            return self
        try:
            await self.nodes_manager.initialize()
            await self.commands_parser.initialize(self.nodes_manager.default_node)
            self._initialize = False
        except BaseException:
            # BaseException so that cancellation also releases connections.
            await self.nodes_manager.aclose()
            await self.nodes_manager.aclose("startup_nodes")
            raise

    return self

496 

async def aclose(self) -> None:
    """
    Close the nodes manager's connections if the client was initialized.

    The client is flagged as uninitialized first, then the regular nodes
    and the startup nodes are closed. Calling this on a client that was
    never initialized (or is already closed) is a no-op.
    """
    if self._initialize:
        return

    if not self._lock:
        self._lock = asyncio.Lock()

    async with self._lock:
        # Re-check under the lock in case a concurrent aclose() got here first.
        if self._initialize:
            return
        self._initialize = True
        await self.nodes_manager.aclose()
        await self.nodes_manager.aclose("startup_nodes")

507 

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
    """Deprecated alias kept for backwards compatibility; delegates to aclose()."""
    await self.aclose()

512 

513 async def __aenter__(self) -> "RedisCluster": 

514 """ 

515 Async context manager entry. Increments a usage counter so that the 

516 connection pool is only closed (via aclose()) when no context is using 

517 the client. 

518 """ 

519 await self._increment_usage() 

520 try: 

521 # Initialize the client (i.e. establish connection, etc.) 

522 return await self.initialize() 

523 except Exception: 

524 # If initialization fails, decrement the counter to keep it in sync 

525 await self._decrement_usage() 

526 raise 

527 

528 async def _increment_usage(self) -> int: 

529 """ 

530 Helper coroutine to increment the usage counter while holding the lock. 

531 Returns the new value of the usage counter. 

532 """ 

533 async with self._usage_lock: 

534 self._usage_counter += 1 

535 return self._usage_counter 

536 

537 async def _decrement_usage(self) -> int: 

538 """ 

539 Helper coroutine to decrement the usage counter while holding the lock. 

540 Returns the new value of the usage counter. 

541 """ 

542 async with self._usage_lock: 

543 self._usage_counter -= 1 

544 return self._usage_counter 

545 

546 async def __aexit__(self, exc_type, exc_value, traceback): 

547 """ 

548 Async context manager exit. Decrements a usage counter. If this is the 

549 last exit (counter becomes zero), the client closes its connection pool. 

550 """ 

551 current_usage = await asyncio.shield(self._decrement_usage()) 

552 if current_usage == 0: 

553 # This was the last active context, so disconnect the pool. 

554 await asyncio.shield(self.aclose()) 

555 

556 def __await__(self) -> Generator[Any, None, "RedisCluster"]: 

557 return self.initialize().__await__() 

558 

559 _DEL_MESSAGE = "Unclosed RedisCluster client" 

560 

561 def __del__( 

562 self, 

563 _warn: Any = warnings.warn, 

564 _grl: Any = asyncio.get_running_loop, 

565 ) -> None: 

566 if hasattr(self, "_initialize") and not self._initialize: 

567 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) 

568 try: 

569 context = {"client": self, "message": self._DEL_MESSAGE} 

570 _grl().call_exception_handler(context) 

571 except RuntimeError: 

572 pass 

573 

async def on_connect(self, connection: Connection) -> None:
    """
    Run the connection's own on_connect, then switch it to READONLY mode.

    :raises ConnectionError: if the server's reply to READONLY is not "OK".
    """
    await connection.on_connect()

    # READONLY is sent on every connection regardless of the node's current
    # role: a failover may change a node's server type later, and on a
    # primary the command does not affect executing write commands.
    await connection.send_command("READONLY")
    reply = await connection.read_response()
    if str_if_bytes(reply) != "OK":
        raise ConnectionError("READONLY command failed")

585 

def get_nodes(self) -> List["ClusterNode"]:
    """Return a new list with every node in the nodes manager's cache."""
    cached_nodes = self.nodes_manager.nodes_cache.values()
    return list(cached_nodes)

589 

def get_primaries(self) -> List["ClusterNode"]:
    """Return the nodes the nodes manager reports with server type PRIMARY."""
    primaries = self.nodes_manager.get_nodes_by_server_type(PRIMARY)
    return primaries

593 

def get_replicas(self) -> List["ClusterNode"]:
    """Return the nodes the nodes manager reports with server type REPLICA."""
    replicas = self.nodes_manager.get_nodes_by_server_type(REPLICA)
    return replicas

597 

def get_random_node(self) -> "ClusterNode":
    """Pick one node at random from the nodes manager's cache."""
    candidates = list(self.nodes_manager.nodes_cache.values())
    return random.choice(candidates)

601 

def get_default_node(self) -> "ClusterNode":
    """Return the nodes manager's current default node."""
    manager = self.nodes_manager
    return manager.default_node

605 

def set_default_node(self, node: "ClusterNode") -> None:
    """
    Make ``node`` the nodes manager's default node.

    :param node:
        | Node to promote; it is looked up by its ``name`` via :meth:`get_node`
    :raises DataError: if ``node`` is falsy or the lookup finds nothing.
    """
    if not (node and self.get_node(node_name=node.name)):
        raise DataError("The requested node does not exist in the cluster.")

    self.nodes_manager.default_node = node

616 

def get_node(
    self,
    host: Optional[str] = None,
    port: Optional[int] = None,
    node_name: Optional[str] = None,
) -> Optional["ClusterNode"]:
    """Look up a node through the nodes manager by (host, port) or node_name."""
    manager = self.nodes_manager
    return manager.get_node(host, port, node_name)

625 

def get_node_from_key(
    self, key: str, replica: bool = False
) -> Optional["ClusterNode"]:
    """
    Return the node cached for the slot that ``key`` hashes to.

    :param key:
        | Key whose slot is computed with :meth:`keyslot`
    :param replica:
        | When true, return the second entry cached for the slot instead of
          the first; ``None`` is returned if the slot has fewer than two
          entries

    :raises SlotNotCoveredError: if the slot has no entry in the slots cache.
    """
    slot = self.keyslot(key)
    slot_nodes = self.nodes_manager.slots_cache.get(slot)
    if not slot_nodes:
        raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')

    if not replica:
        # Index 0 is what callers get when no replica is requested.
        return slot_nodes[0]

    if len(self.nodes_manager.slots_cache[slot]) < 2:
        return None
    return slot_nodes[1]

653 

def get_random_primary_or_all_nodes(self, command_name):
    """
    Return one random node suitable for ``command_name``.

    When ``read_from_replicas`` is enabled and the command is listed in
    READ_COMMANDS, the node is drawn from all nodes; otherwise it is drawn
    from the primaries only.
    """
    any_node_allowed = self.read_from_replicas and command_name in READ_COMMANDS
    pick = self.get_random_node if any_node_allowed else self.get_random_primary_node
    return pick()

662 

def get_random_primary_node(self) -> "ClusterNode":
    """Pick one node at random from :meth:`get_primaries`."""
    primaries = self.get_primaries()
    return random.choice(primaries)

668 

async def get_nodes_from_slot(self, command: str, *args):
    """
    Return a one-element list with the node serving the command's slot.

    The slot comes from :meth:`_determine_slot`. Replica reads and the load
    balancing strategy are only offered to the nodes manager when the
    command is listed in READ_COMMANDS.
    """
    slot = await self._determine_slot(command, *args)
    is_read_command = command in READ_COMMANDS
    node = self.nodes_manager.get_node_from_slot(
        slot,
        self.read_from_replicas and is_read_command,
        self.load_balancing_strategy if is_read_command else None,
    )
    return [node]

681 

def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
    """
    Return the nodes remembered from the last FT.AGGREGATE routing.

    :raises RedisClusterException: if no such nodes have been recorded yet.
    """
    remembered = self._aggregate_nodes
    if remembered:
        return remembered

    raise RedisClusterException(
        "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
    )

692 

def keyslot(self, key: EncodableT) -> int:
    """
    Compute the hash slot of ``key``.

    The key is encoded with the client's encoder before hashing.

    See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
    """
    encoded_key = self.encoder.encode(key)
    return key_slot(encoded_key)

700 

def get_encoder(self) -> Encoder:
    """Return the encoder instance this client was built with."""
    encoder = self.encoder
    return encoder

704 

def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
    """
    Return the keyword arguments used for node connections.

    The stored dict itself is returned, not a copy.
    """
    connection_options = self.connection_kwargs
    return connection_options

708 

def set_retry(self, retry: Retry) -> None:
    """Replace the client's retry object with ``retry``."""
    self.retry = retry

711 

def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
    """Register ``callback`` as the response callback for ``command``, replacing any existing one."""
    self.response_callbacks.update({command: callback})

715 

async def _determine_nodes(
    self,
    command: str,
    *args: Any,
    request_policy: RequestPolicy,
    node_flag: Optional[str] = None,
) -> List["ClusterNode"]:
    """
    Resolve the list of nodes a command should be executed on.

    :param command:
        | Command name
    :param args:
        | Remaining command arguments, used only for keyed routing
    :param request_policy:
        | Policy to apply unless a node flag overrides it
    :param node_flag:
        | Explicit node flag; when falsy, the flag predefined for the
          command (if any) is used

    :returns: the target nodes produced by the policy's callback.
    """
    # An explicit flag wins; otherwise fall back to the command's predefined one.
    flag = node_flag or self.command_flags.get(command)

    # A known flag overrides the requested policy (backward compatibility).
    request_policy = self._command_flags_mapping.get(flag, request_policy)
    resolve_targets = self._policies_callback_mapping[request_policy]

    # The callbacks have different call shapes depending on the policy.
    if request_policy == RequestPolicy.DEFAULT_KEYED:
        targets = await resolve_targets(command, *args)
    elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
        targets = resolve_targets(command)
    else:
        targets = resolve_targets()

    # Remember where FT.AGGREGATE went; get_special_nodes() reads this.
    if command.lower() == "ft.aggregate":
        self._aggregate_nodes = targets

    return targets

745 

async def _determine_slot(self, command: str, *args: Any) -> int:
    """
    Work out which hash slot a command must be routed to.

    :param command:
        | Command name
    :param args:
        | Command arguments (without the command name)

    :returns: the slot number; a random slot for key-less EVAL/EVALSHA and
        FCALL/FCALL_RO calls.
    :raises RedisClusterException:
        if EVAL/EVALSHA has fewer than two arguments, if no key can be
        found for another command, or if the keys map to different slots.
    """
    if self.command_flags.get(command) == SLOT_ID:
        # The command carries the slot ID itself as its first argument.
        return int(args[0])

    upper_command = command.upper()

    # EVAL and EVALSHA are common enough that it's wasteful to go to the
    # redis server to parse the keys. Besides, there is a bug in redis<7.0
    # where `self._get_command_keys()` fails anyway. So, they are special
    # cased here.
    # - issue: https://github.com/redis/redis/issues/9493
    # - fix: https://github.com/redis/redis/pull/9733
    if upper_command in ("EVAL", "EVALSHA"):
        # command syntax: EVAL "script body" num_keys ...
        if len(args) < 2:
            raise RedisClusterException(
                f"Invalid args in command: {command, *args}"
            )
        key_count = int(args[1])
        keys = args[2 : 2 + key_count]
        if not keys:
            # A script without keys can run on any node: pick a random slot.
            return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
    else:
        keys = await self.commands_parser.get_keys(command, *args)
        if not keys:
            if upper_command in ("FCALL", "FCALL_RO"):
                # A function called with 0 keys can run on any node.
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster. "
                "Missing key.\nYou can execute the command by specifying "
                f"target nodes.\nCommand: {args}"
            )

    # One key yields exactly one slot; several keys must all agree.
    distinct_slots = {self.keyslot(key) for key in keys}
    if len(distinct_slots) != 1:
        raise RedisClusterException(
            f"{command} - all keys must map to the same key slot"
        )

    return distinct_slots.pop()

796 

797 def _is_node_flag(self, target_nodes: Any) -> bool: 

798 return isinstance(target_nodes, str) and target_nodes in self.node_flags 

799 

800 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: 

801 if isinstance(target_nodes, list): 

802 nodes = target_nodes 

803 elif isinstance(target_nodes, ClusterNode): 

804 # Supports passing a single ClusterNode as a variable 

805 nodes = [target_nodes] 

806 elif isinstance(target_nodes, dict): 

807 # Supports dictionaries of the format {node_name: node}. 

808 # It enables to execute commands with multi nodes as follows: 

809 # rc.cluster_save_config(rc.get_primaries()) 

810 nodes = list(target_nodes.values()) 

811 else: 

812 raise TypeError( 

813 "target_nodes type can be one of the following: " 

814 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," 

815 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. " 

816 f"The passed type is {type(target_nodes)}" 

817 ) 

818 return nodes 

819 

async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
    """
    Execute a raw command on the appropriate cluster node or target_nodes.

    It will retry the command as specified by the retries property of
    the :attr:`retry` & then raise an exception.

    :param args:
        | Raw command args
    :param kwargs:

        - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
          or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
        - Rest of the kwargs are passed to the Redis connection

    :raises RedisClusterException: if target_nodes is not provided & the command
        can't be mapped to a slot
    """
    command = args[0]
    target_nodes = []
    target_nodes_specified = False
    retry_attempts = self.retry.get_retries()

    passed_targets = kwargs.pop("target_nodes", None)
    if passed_targets and not self._is_node_flag(passed_targets):
        # Explicit nodes were given: use them verbatim and never retry,
        # since a retry could not pick a different node anyway.
        target_nodes = self._parse_target_nodes(passed_targets)
        target_nodes_specified = True
        retry_attempts = 0

    command_policies = await self._policy_resolver.resolve(args[0].lower())

    if not command_policies and not target_nodes_specified:
        command_flag = self.command_flags.get(command)
        if not command_flag:
            # Fallback to default policy
            if not self.get_default_node():
                slot = None
            else:
                slot = await self._determine_slot(*args)
            # Compare against None explicitly: hash slot 0 is a real slot
            # and must be routed with the keyed policies like any other.
            if slot is None:
                command_policies = CommandPolicies()
            else:
                command_policies = CommandPolicies(
                    request_policy=RequestPolicy.DEFAULT_KEYED,
                    response_policy=ResponsePolicy.DEFAULT_KEYED,
                )
        else:
            if command_flag in self._command_flags_mapping:
                command_policies = CommandPolicies(
                    request_policy=self._command_flags_mapping[command_flag]
                )
            else:
                command_policies = CommandPolicies()
    elif not command_policies and target_nodes_specified:
        command_policies = CommandPolicies()

    # Add one for the first execution
    execute_attempts = 1 + retry_attempts
    for _ in range(execute_attempts):
        if self._initialize:
            await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
        try:
            if not target_nodes_specified:
                # Determine the nodes to execute the command on
                target_nodes = await self._determine_nodes(
                    *args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {args} command on"
                    )

            if len(target_nodes) == 1:
                # Return the processed result
                ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                if command in self.result_callbacks:
                    ret = self.result_callbacks[command](
                        command, {target_nodes[0].name: ret}, **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](ret)
            else:
                # Fan the command out to every target concurrently and
                # key the collected replies by node name.
                keys = [node.name for node in target_nodes]
                values = await asyncio.gather(
                    *(
                        asyncio.create_task(
                            self._execute_command(node, *args, **kwargs)
                        )
                        for node in target_nodes
                    )
                )
                if command in self.result_callbacks:
                    return self.result_callbacks[command](
                        command, dict(zip(keys, values)), **kwargs
                    )
                return self._policies_callback_mapping[
                    command_policies.response_policy
                ](dict(zip(keys, values)))
        except Exception as e:
            if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                # The nodes and slots cache should be reinitialized.
                # Try again with the new cluster setup.
                retry_attempts -= 1
                continue
            # Not retryable (or retries exhausted): propagate unchanged.
            raise

936 

async def _execute_command(
    self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
) -> Any:
    """
    Send one command to ``target_node``, following cluster redirections.

    The command is attempted at most ``RedisClusterRequestTTL`` times.
    MOVED, ASK and TRYAGAIN replies lead to another attempt; connection
    level and cluster-down style errors close the client and are re-raised
    to the caller.

    :raises ClusterError: when all attempts were used up by redirections
    """
    ask_pending = False
    moved_pending = False
    ask_target = None
    attempts_left = self.RedisClusterRequestTTL

    while attempts_left > 0:
        attempts_left -= 1
        try:
            if ask_pending:
                # One-shot redirect: announce ASKING on the node named by
                # the previous ASK reply, then send the command there.
                target_node = self.get_node(node_name=ask_target)
                await target_node.execute_command("ASKING")
                ask_pending = False
            elif moved_pending:
                # MOVED occurred and the slots cache was updated,
                # refresh the target node
                slot = await self._determine_slot(*args)
                is_read = args[0] in READ_COMMANDS
                target_node = self.nodes_manager.get_node_from_slot(
                    slot,
                    self.read_from_replicas and is_read,
                    self.load_balancing_strategy if is_read else None,
                )
                moved_pending = False

            return await target_node.execute_command(*args, **kwargs)
        except (BusyLoadingError, MaxConnectionsError):
            # BusyLoadingError is passed through untouched.
            # MaxConnectionsError indicates client-side resource exhaustion
            # (too many connections in the pool), not a node failure, so
            # the cluster is not reinitialized for it either.
            raise
        except (ConnectionError, TimeoutError):
            # Connection retries are being handled in the node's Retry
            # object. Drop the failed node from the startup nodes, force a
            # reinitialization of the node/slots setup and let the caller
            # decide whether to try again.
            self.nodes_manager.startup_nodes.pop(target_node.name, None)
            await self.aclose()
            raise
        except (ClusterDownError, SlotNotCoveredError):
            # ClusterDownError can occur during a failover;
            # SlotNotCoveredError can occur when the cluster is not fully
            # initialized or can be a temporary issue. In both cases the
            # topology is reinitialized before the error is re-raised.
            await self.aclose()
            await asyncio.sleep(0.25)
            raise
        except MovedError as e:
            # First, we will try to patch the slots/nodes cache with the
            # redirected node output and try again. If MovedError exceeds
            # 'reinitialize_steps' number of times, we will force
            # reinitializing the tables, and then try again.
            # 'reinitialize_steps' counter will increase faster when
            # the same client object is shared between multiple threads. To
            # reduce the frequency you can set this variable in the
            # RedisCluster constructor.
            self.reinitialize_counter += 1
            if (
                not self.reinitialize_steps
                or self.reinitialize_counter % self.reinitialize_steps != 0
            ):
                self.nodes_manager._moved_exception = e
            else:
                await self.aclose()
                # Reset the counter
                self.reinitialize_counter = 0
            moved_pending = True
        except AskError as e:
            ask_target = get_node_name(host=e.host, port=e.port)
            ask_pending = True
        except TryAgainError:
            # Only start pausing once half of the attempts are used up.
            if attempts_left < self.RedisClusterRequestTTL / 2:
                await asyncio.sleep(0.05)

    raise ClusterError("TTL exhausted.")

1024 

def pipeline(
    self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
) -> "ClusterPipeline":
    """
    Build a new :class:`~.ClusterPipeline` bound to this client.

    ``transaction`` is handed to the pipeline constructor unchanged.
    ``shard_hint`` is not accepted by the cluster implementation.

    :raises RedisClusterException: if shard_hint is a truthy value
    """
    if not shard_hint:
        return ClusterPipeline(self, transaction)

    raise RedisClusterException("shard_hint is deprecated in cluster mode")

1039 

def lock(
    self,
    name: KeyT,
    timeout: Optional[float] = None,
    sleep: float = 0.1,
    blocking: bool = True,
    blocking_timeout: Optional[float] = None,
    lock_class: Optional[Type[Lock]] = None,
    thread_local: bool = True,
    raise_on_release_error: bool = True,
) -> Lock:
    """
    Create a Lock object for key ``name`` that behaves like
    threading.Lock.

    ``timeout`` is the maximum life of the lock. When omitted, the lock
    stays held until release() is called.

    ``sleep`` is how long each loop iteration sleeps while the lock is in
    blocking mode and another client currently holds it.

    ``blocking`` selects whether ``acquire`` waits for the lock or fails
    immediately (``acquire`` then returns False and nothing is acquired).
    Defaults to True. A ``blocking`` argument passed to ``acquire``
    overrides this value.

    ``blocking_timeout`` is the maximum number of seconds (float or
    integer) to spend trying to acquire the lock. ``None`` means keep
    trying forever.

    ``lock_class`` forces a specific lock implementation. As of redis-py
    3.0 the only implementation shipped is ``Lock`` (a Lua-based lock), so
    this is normally only needed for a custom lock class.

    ``thread_local`` selects whether the lock token is kept in
    thread-local storage. By default it is, so that a thread only sees its
    own token and not one set by another thread. Consider this timeline:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                 thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
                 Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
                 key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
                 thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
                 token is *not* stored in thread local storage, then
                 thread-1 would see the token value as "xyz" and would be
                 able to successfully release the thread-2's lock.

    Some use cases need thread-local storage disabled, for example when
    one thread acquires a lock and hands the instance to a worker thread
    that releases it later; with thread-local storage enabled the worker
    would not see the token set by the acquiring thread. Such cases are
    assumed to be uncommon, hence the default.

    ``raise_on_release_error`` selects whether an exception is raised when
    the lock is no longer owned on leaving the context manager. Defaults
    to True; if False, a warning is logged and the exception suppressed.
    """
    factory = Lock if lock_class is None else lock_class
    options = {
        "timeout": timeout,
        "sleep": sleep,
        "blocking": blocking,
        "blocking_timeout": blocking_timeout,
        "thread_local": thread_local,
        "raise_on_release_error": raise_on_release_error,
    }
    return factory(self, name, **options)

1120 

async def transaction(
    self,
    func: Callable[["ClusterPipeline"], Coroutine[Any, Any, Any]],
    *watches,
    **kwargs,
):
    """
    Convenience method for executing the callable `func` as a transaction
    while watching all keys specified in `watches`. The 'func' callable
    should expect a single argument which is a Pipeline object, and its
    result is awaited.

    Recognized keyword arguments (all removed from ``kwargs``):

    - shard_hint: forwarded to :meth:`pipeline`
    - value_from_callable: when truthy, return the awaited result of
      ``func`` instead of the result of ``pipe.execute()``
    - watch_delay: seconds to wait before retrying after a WatchError;
      ignored when ``None`` or not positive

    The whole watch/func/execute sequence is repeated for as long as it
    fails with WatchError.
    """
    shard_hint = kwargs.pop("shard_hint", None)
    value_from_callable = kwargs.pop("value_from_callable", False)
    watch_delay = kwargs.pop("watch_delay", None)
    async with self.pipeline(True, shard_hint) as pipe:
        while True:
            try:
                if watches:
                    await pipe.watch(*watches)
                func_value = await func(pipe)
                exec_value = await pipe.execute()
                return func_value if value_from_callable else exec_value
            except WatchError:
                if watch_delay is not None and watch_delay > 0:
                    # Yield to the event loop while waiting; a blocking
                    # time.sleep() here would stall every other task.
                    await asyncio.sleep(watch_delay)
                continue

1144 

1145 

class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are created on demand by :meth:`acquire_connection` (up to
    ``max_connections``) and idle ones are kept in the ``_free`` queue.
    Two nodes compare equal, and hash alike, when their ``name`` matches.
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        """
        :param host: node host; ``"localhost"`` is resolved through
            ``socket.gethostbyname`` first
        :param port: node port
        :param server_type: role label stored on the node
        :param max_connections: upper bound on connections created here
        :param connection_class: class used to build new connections
        :param connection_kwargs: passed to ``connection_class``; a
            ``response_callbacks`` entry is removed and kept on the node
        """
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # Every connection ever created by this node ...
        self._connections: List[Connection] = []
        # ... and the subset that is currently idle.
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same key that equality uses.
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # ``_connections`` is missing when __init__ failed early; there is
        # nothing to report in that case.
        for connection in getattr(self, "_connections", ()):
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop to report to.
                    pass
                # One warning per node is enough.
                break

    async def disconnect(self) -> None:
        """
        Disconnect all connections of this node concurrently.

        Every connection is attempted; afterwards the first exception
        collected (if any) is raised.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Return an idle connection, creating a new one if none is free.

        :raises MaxConnectionsError: when no connection is free and
            ``max_connections`` connections already exist
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        """
        self._free.append(connection)

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and post-process it.

        ``NEVER_DECODE`` in ``kwargs`` disables decoding for the read.
        ``EMPTY_RESPONSE`` in ``kwargs`` supplies the value returned when
        the read fails with ResponseError. Both options, as well as
        ``keys``, are stripped before the per-command response callback
        (if one is registered) is invoked.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send a single command on a pooled connection and return its
        parsed reply.
        """
        connection = self.acquire_connection()
        try:
            await connection.send_packed_command(connection.pack_command(*args), False)
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            # Hand the connection back even when sending fails; otherwise
            # it would stay in ``_connections`` but never be reusable.
            self.release(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in one batch and store each reply (or the
        exception raised while reading it) on the command's ``result``.

        :return: True if reading at least one reply raised an exception
        """
        connection = self.acquire_connection()
        try:
            await connection.send_packed_command(
                connection.pack_commands(cmd.args for cmd in commands), False
            )

            # Read responses
            ret = False
            for cmd in commands:
                try:
                    cmd.result = await self.parse_response(
                        connection, cmd.args[0], **cmd.kwargs
                    )
                except Exception as e:
                    cmd.result = e
                    ret = True

            return ret
        finally:
            # Release on every exit path (send failure, cancellation) so
            # the connection is not lost from the free queue.
            self.release(connection)

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every currently idle connection with ``token``.

        Connections are taken out of the free queue while AUTH is in
        flight and put back afterwards, also when AUTH raises.
        """
        tmp_queue = collections.deque()
        try:
            while self._free:
                conn = self._free.popleft()
                # Track it before awaiting so a failure cannot drop it.
                tmp_queue.append(conn)
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
        finally:
            while tmp_queue:
                self._free.append(tmp_queue.popleft())

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

1363 

1364 

class NodesManager:
    """
    Keeps the client's view of the cluster topology.

    ``nodes_cache`` maps node name to node, ``slots_cache`` maps a hash
    slot to its list of nodes (primary first), and ``startup_nodes`` are
    the nodes queried with ``CLUSTER SLOTS`` by :meth:`initialize`.
    """

    __slots__ = (
        "_dynamic_startup_nodes",
        "_moved_exception",
        "_event_dispatcher",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "read_load_balancer",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )

    def __init__(
        self,
        startup_nodes: List["ClusterNode"],
        require_full_coverage: bool,
        connection_kwargs: Dict[str, Any],
        dynamic_startup_nodes: bool = True,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        self.startup_nodes = {node.name: node for node in startup_nodes}
        self.require_full_coverage = require_full_coverage
        self.connection_kwargs = connection_kwargs
        self.address_remap = address_remap

        self.default_node: "ClusterNode" = None
        self.nodes_cache: Dict[str, "ClusterNode"] = {}
        self.slots_cache: Dict[int, List["ClusterNode"]] = {}
        self.read_load_balancer = LoadBalancer()

        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
        # Last MOVED redirection not yet applied to ``slots_cache``.
        self._moved_exception: MovedError = None
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """
        Look a node up in ``nodes_cache`` by (host, port) or by name.

        :return: the cached node, or ``None`` if it is not cached
        :raises DataError: if neither host+port nor node_name is given
        """
        if host and port:
            # the user passed host and port
            if host == "localhost":
                host = socket.gethostbyname(host)
            return self.nodes_cache.get(get_node_name(host=host, port=port))
        elif node_name:
            return self.nodes_cache.get(node_name)
        else:
            raise DataError(
                "get_node requires one of the following: 1. node name 2. host and port"
            )

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge ``new`` into ``old`` in place.

        Nodes replaced by a different object under the same name are
        disconnected in a background task. With ``remove_old`` the
        entries of ``old`` that are absent from ``new`` are removed and
        disconnected as well.
        """
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    task = asyncio.create_task(old.pop(name).disconnect())  # noqa

        for name, node in new.items():
            if name in old:
                if old[name] is node:
                    continue
                task = asyncio.create_task(old[name].disconnect())  # noqa
            old[name] = node

    def update_moved_exception(self, exception):
        """Remember a MOVED error to be applied on the next slot lookup."""
        self._moved_exception = exception

    def _update_moved_slots(self) -> None:
        """
        Apply the pending MOVED redirection to ``slots_cache``.

        The redirected node becomes the primary of the slot named in the
        error; the pending error is cleared afterwards.
        """
        e = self._moved_exception
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})

        # The slot may be absent from the cache when coverage is partial.
        slot_nodes = self.slots_cache.get(e.slot_id)
        if slot_nodes and slot_nodes[0] == redirected_node:
            # The cache already lists this node as the slot's primary (for
            # instance another command applied the same redirection
            # first). Re-running the failover swap here would demote it to
            # REPLICA and mangle the node list, so leave everything as is.
            pass
        elif slot_nodes and redirected_node in slot_nodes:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        else:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        # Reset moved_exception
        self._moved_exception = None

    def get_node_from_slot(
        self,
        slot: int,
        read_from_replicas: bool = False,
        load_balancing_strategy=None,
    ) -> "ClusterNode":
        """
        Return the node that should serve ``slot``.

        A pending MOVED redirection is applied first. Without a load
        balancing strategy the slot's first node is returned; otherwise
        the load balancer picks the index. ``read_from_replicas=True``
        without a strategy selects round robin.

        :raises SlotNotCoveredError: if the slot has no nodes in the cache
        """
        if self._moved_exception:
            self._update_moved_slots()

        if read_from_replicas is True and load_balancing_strategy is None:
            load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN

        try:
            if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
                # get the server index using the strategy defined in load_balancing_strategy
                primary_name = self.slots_cache[slot][0].name
                node_idx = self.read_load_balancer.get_server_index(
                    primary_name, len(self.slots_cache[slot]), load_balancing_strategy
                )
                return self.slots_cache[slot][node_idx]
            return self.slots_cache[slot][0]
        except (IndexError, KeyError, TypeError):
            # KeyError: ``slots_cache`` is a dict, so a slot that was never
            # discovered is reported this way rather than via IndexError.
            raise SlotNotCoveredError(
                f'Slot "{slot}" not covered by the cluster. '
                f'"require_full_coverage={self.require_full_coverage}"'
            )

    def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
        """Return the cached nodes whose ``server_type`` equals the argument."""
        return [
            node
            for node in self.nodes_cache.values()
            if node.server_type == server_type
        ]

    async def initialize(self) -> None:
        """
        Rebuild the node and slot caches from ``CLUSTER SLOTS``.

        Startup nodes are queried one after another until the collected
        slot map covers every hash slot or no startup node is left.

        :raises RedisClusterException: if no startup node could be
            queried, if the nodes disagree on slot ownership more than
            five times, or if coverage is incomplete while
            ``require_full_coverage`` is set
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Convert to tuple to prevent RuntimeError if self.startup_nodes
        # is modified during iteration
        for startup_node in tuple(self.startup_nodes.values()):
            try:
                # Make sure cluster mode is enabled on this node
                try:
                    self._event_dispatcher.dispatch(
                        AfterAsyncClusterInstantiationEvent(
                            self.nodes_cache,
                            self.connection_kwargs.get("credential_provider", None),
                        )
                    )
                    cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
                except ResponseError:
                    raise RedisClusterException(
                        "Cluster mode is not enabled on this node"
                    )
                startup_nodes_reachable = True
            except Exception as e:
                # Try the next startup node.
                # The exception is saved and raised only if we have no more nodes.
                exception = e
                continue

            # CLUSTER SLOTS command results in the following output:
            # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
            # where each node contains the following list: [IP, port, node_id]
            # Therefore, cluster_slots[0][2][0] will be the IP address of the
            # primary node of the first slot section.
            # If there's only one server in the cluster, its ``host`` is ''
            # Fix it to the host in startup_nodes
            if (
                len(cluster_slots) == 1
                and not cluster_slots[0][2][0]
                and len(self.startup_nodes) == 1
            ):
                cluster_slots[0][2][0] = startup_node.host

            for slot in cluster_slots:
                for i in range(2, len(slot)):
                    slot[i] = [str_if_bytes(val) for val in slot[i]]
                primary_node = slot[2]
                host = primary_node[0]
                if host == "":
                    host = startup_node.host
                port = int(primary_node[1])
                host, port = self.remap_host_port(host, port)

                nodes_for_slot = []

                target_node = tmp_nodes_cache.get(get_node_name(host, port))
                if not target_node:
                    target_node = ClusterNode(
                        host, port, PRIMARY, **self.connection_kwargs
                    )
                # add this node to the nodes cache
                tmp_nodes_cache[target_node.name] = target_node
                nodes_for_slot.append(target_node)

                replica_nodes = slot[3:]
                for replica_node in replica_nodes:
                    host = replica_node[0]
                    # Normalize to int exactly like the primary's port so
                    # both kinds of node carry the same port type.
                    port = int(replica_node[1])
                    host, port = self.remap_host_port(host, port)

                    target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_replica_node:
                        target_replica_node = ClusterNode(
                            host, port, REPLICA, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_replica_node.name] = target_replica_node
                    nodes_for_slot.append(target_replica_node)

                for i in range(int(slot[0]), int(slot[1]) + 1):
                    if i not in tmp_slots:
                        tmp_slots[i] = nodes_for_slot
                    else:
                        # Validate that 2 nodes want to use the same slot cache
                        # setup
                        tmp_slot = tmp_slots[i][0]
                        if tmp_slot.name != target_node.name:
                            disagreements.append(
                                f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                            )

                            if len(disagreements) > 5:
                                raise RedisClusterException(
                                    f"startup_nodes could not agree on a valid "
                                    f"slots cache: {', '.join(disagreements)}"
                                )

            # Validate if all slots are covered or if we should try next startup node
            fully_covered = all(
                i in tmp_slots for i in range(REDIS_CLUSTER_HASH_SLOTS)
            )
            if fully_covered:
                break

        if not startup_nodes_reachable:
            raise RedisClusterException(
                f"Redis Cluster cannot be connected. Please provide at least "
                f"one reachable node: {str(exception)}"
            ) from exception

        # Check if the slots are not fully covered
        if not fully_covered and self.require_full_coverage:
            # Despite the requirement that the slots be covered, there
            # isn't a full coverage
            raise RedisClusterException(
                f"All slots are not covered after query all startup_nodes. "
                f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                f"covered..."
            )

        # Set the tmp variables to the real variables
        self.slots_cache = tmp_slots
        self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

        if self._dynamic_startup_nodes:
            # Populate the startup nodes with all discovered nodes
            self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

        # Set the default node
        self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
        # If initialize was called after a MovedError, clear it
        self._moved_exception = None

    async def aclose(self, attr: str = "nodes_cache") -> None:
        """
        Clear the default node and disconnect every node stored in the
        mapping attribute named ``attr``.
        """
        self.default_node = None
        await asyncio.gather(
            *(
                asyncio.create_task(node.disconnect())
                for node in getattr(self, attr).values()
            )
        )

    def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
        """
        Remap the host and port returned from the cluster to a different
        internal value.  Useful if the client is not connecting directly
        to the cluster.
        """
        if self.address_remap:
            return self.address_remap((host, port))
        return host, port

1673 

1674 

1675class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): 

1676 """ 

1677 Create a new ClusterPipeline object. 

1678 

1679 Usage:: 

1680 

1681 result = await ( 

1682 rc.pipeline() 

1683 .set("A", 1) 

1684 .get("A") 

1685 .hset("K", "F", "V") 

1686 .hgetall("K") 

1687 .mset_nonatomic({"A": 2, "B": 3}) 

1688 .get("A") 

1689 .get("B") 

1690 .delete("A", "B", "K") 

1691 .execute() 

1692 ) 

1693 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] 

1694 

1695 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which 

1696 are split across multiple nodes, you'll get multiple results for them in the array. 

1697 

1698 Retryable errors: 

1699 - :class:`~.ClusterDownError` 

1700 - :class:`~.ConnectionError` 

1701 - :class:`~.TimeoutError` 

1702 

1703 Redirection errors: 

1704 - :class:`~.TryAgainError` 

1705 - :class:`~.MovedError` 

1706 - :class:`~.AskError` 

1707 

1708 :param client: 

1709 | Existing :class:`~.RedisCluster` client 

1710 """ 

1711 

1712 __slots__ = ("cluster_client", "_transaction", "_execution_strategy") 

1713 

def __init__(
    self, client: RedisCluster, transaction: Optional[bool] = None
) -> None:
    """
    Bind the pipeline to ``client`` and choose how it will execute.

    A truthy ``transaction`` selects ``TransactionStrategy``; otherwise
    ``PipelineStrategy`` is used.
    """
    self.cluster_client = client
    self._transaction = transaction
    if self._transaction:
        strategy = TransactionStrategy(self)
    else:
        strategy = PipelineStrategy(self)
    self._execution_strategy: ExecutionStrategy = strategy

1724 

async def initialize(self) -> "ClusterPipeline":
    """Initialize the execution strategy and return this pipeline."""
    strategy = self._execution_strategy
    await strategy.initialize()
    return self

1728 

1729 async def __aenter__(self) -> "ClusterPipeline": 

1730 return await self.initialize() 

1731 

1732 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: 

1733 await self.reset() 

1734 

1735 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: 

1736 return self.initialize().__await__() 

1737 

1738 def __bool__(self) -> bool: 

1739 "Pipeline instances should always evaluate to True on Python 3+" 

1740 return True 

1741 

1742 def __len__(self) -> int: 

1743 return len(self._execution_strategy) 

1744 

def execute_command(
    self, *args: Union[KeyT, EncodableT], **kwargs: Any
) -> "ClusterPipeline":
    """
    Hand a raw command to the execution strategy.

    :param args:
        | Raw command args
    :param kwargs:

        - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
          or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
        - Rest of the kwargs are passed to the Redis connection

    :returns: whatever the strategy's ``execute_command`` returns
    """
    strategy = self._execution_strategy
    return strategy.execute_command(*args, **kwargs)

1760 

async def execute(
    self, raise_on_error: bool = True, allow_redirections: bool = True
) -> List[Any]:
    """
    Run the queued commands through the execution strategy.

    It will retry the commands as specified by retries specified in :attr:`retry`
    & then raise an exception.

    The pipeline is reset afterwards whether execution succeeded or failed.

    :param raise_on_error:
        | Raise the first error if there are any errors
    :param allow_redirections:
        | Whether to retry each failed command individually in case of redirection
          errors

    :raises RedisClusterException: if target_nodes is not provided & the command
        can't be mapped to a slot
    """
    try:
        results = await self._execution_strategy.execute(
            raise_on_error, allow_redirections
        )
    finally:
        # Always leave the pipeline empty, even when execution raised.
        await self.reset()
    return results

1785 

def _split_command_across_slots(
    self, command: str, *keys: KeyT
) -> "ClusterPipeline":
    """
    Queue ``command`` once per group of keys produced by the cluster client's
    ``_partition_keys_by_slot``.

    :returns: this pipeline
    """
    keys_by_slot = self.cluster_client._partition_keys_by_slot(keys)
    for slot_keys in keys_by_slot.values():
        self.execute_command(command, *slot_keys)

    return self

1793 

async def reset(self):
    """
    Reset back to an empty pipeline by resetting the execution strategy.
    """
    strategy = self._execution_strategy
    await strategy.reset()

1799 

def multi(self):
    """
    Start a transactional block of the pipeline after WATCH commands
    are issued. End the transactional block with `execute`.

    The call is forwarded to the execution strategy.
    """
    strategy = self._execution_strategy
    strategy.multi()

1806 

async def discard(self):
    """Forward ``discard`` to the execution strategy."""
    strategy = self._execution_strategy
    await strategy.discard()

1810 

async def watch(self, *names):
    """Watches the values at keys ``names`` (forwarded to the execution strategy)"""
    strategy = self._execution_strategy
    await strategy.watch(*names)

1814 

async def unwatch(self):
    """Unwatches all previously specified keys (forwarded to the execution strategy)"""
    strategy = self._execution_strategy
    await strategy.unwatch()

1818 

async def unlink(self, *names):
    """Forward an UNLINK of ``names`` to the execution strategy."""
    strategy = self._execution_strategy
    await strategy.unlink(*names)

1821 

def mset_nonatomic(
    self, mapping: Mapping[AnyKeyT, EncodableT]
) -> "ClusterPipeline":
    """
    Forward a non-atomic multi-key set of ``mapping`` to the execution strategy
    and return its result.
    """
    strategy = self._execution_strategy
    return strategy.mset_nonatomic(mapping)

1826 

1827 

# Override every pipeline-blocked command on ClusterPipeline with the callable
# returned by block_pipeline_command(), using the snake_case method name.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    # mset_nonatomic has its own implementation on ClusterPipeline; leave it.
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))

1834 

1835 

class PipelineCommand:
    """A single queued command together with its result slot."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index the caller assigned to this command (used in error messages).
        self.position = position
        # Raw command arguments; args[0] is the command name.
        self.args = args
        # Extra options supplied alongside the command.
        self.kwargs = kwargs
        # Filled in after execution: either the reply or the exception raised.
        self.result: Union[Any, Exception] = None
        # Routing/response policies, assigned during pipeline execution.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)

1846 

1847 

class ExecutionStrategy(ABC):
    """Interface of the objects :class:`ClusterPipeline` delegates its work to."""

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Prepare the execution strategy for use.

        See ClusterPipeline.initialize()
        """

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Add a raw command to the pipeline.

        See ClusterPipeline.execute_command()
        """

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        See ClusterPipeline.execute()
        """

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Issue multiple MSET commands according to the provided slot/pairs mapping.

        See ClusterPipeline.mset_nonatomic()
        """

    @abstractmethod
    async def reset(self):
        """
        Reset the current execution strategy.

        See: ClusterPipeline.reset()
        """

    @abstractmethod
    def multi(self):
        """
        Start a transactional context.

        See: ClusterPipeline.multi()
        """

    @abstractmethod
    async def watch(self, *names):
        """
        Watch the given keys.

        See: ClusterPipeline.watch()
        """

    @abstractmethod
    async def unwatch(self):
        """
        Unwatch all previously specified keys.

        See: ClusterPipeline.unwatch()
        """

    @abstractmethod
    async def discard(self):
        """
        See: ClusterPipeline.discard()
        """

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink a key specified by ``names``.

        See: ClusterPipeline.unlink()
        """

    @abstractmethod
    def __len__(self) -> int:
        """
        See: ClusterPipeline.__len__()
        """

1946 

1947 

class AbstractStrategy(ExecutionStrategy):
    """
    Shared base for the concrete strategies.

    Holds the owning pipeline and the queue of :class:`PipelineCommand`
    objects, and implements the parts that do not depend on how the queue is
    eventually executed.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the cluster client if its ``_initialize`` flag is set, empty
        the command queue and return the owning pipeline.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Queue a command; its position is its index in the queue.

        :returns: the owning pipeline
        """
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled

        :param exception: exception whose first argument is rewritten in place
        :param number: command number shown in the message
        :param command: the command's arguments; a bare ``str``/``bytes``
            (e.g. ``"MULTI"``) is accepted and treated as a single argument
        """
        # Iterating a bare string would join its characters ("M U L T I").
        if isinstance(command, (str, bytes)):
            command = (command,)
        cmd = " ".join(map(safe_str, command))
        # Exceptions created without arguments have an empty ``args`` tuple.
        original = exception.args[0] if exception.args else ""
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {original}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        """Return the number of queued commands."""
        return len(self._command_queue)

2016 

2017 

class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional strategy: queued commands are grouped per target node,
    each group is sent through that node's ``execute_pipeline`` and the groups
    run concurrently.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one MSET per hash slot touched by ``mapping``.

        :returns: the owning pipeline
        """
        encoder = self._pipe.cluster_client.encoder

        # slot -> flat [key, value, key, value, ...] list for that slot
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the queued commands, retrying on ``RedisCluster.ERRORS_ALLOW_RETRY``.

        The number of retries is taken from the cluster client's ``retry``
        object. Before each retry the client is closed and execution pauses for
        0.25 seconds. The queue is emptied when this method returns or raises.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of
              redirection errors
        :returns: one result per queued command, ``[]`` for an empty queue
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # Retries exhausted: propagate the last error.
                        raise
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        Run one execution attempt over ``stack``.

        Only commands without a truthy result, or whose result is an exception,
        are (re)sent. Each is resolved to exactly one node, then every node's
        batch is executed concurrently.

        :raises RedisClusterException: when a command resolves to no node or
            to more than one node
        :returns: ``cmd.result`` for every command in ``stack``
        """
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # node name -> (node, [commands routed to it])
        nodes = {}
        for cmd in todo:
            # NOTE: ``target_nodes`` is removed from the command's kwargs here.
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Explicit node objects were given; no routing needed.
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        # Compare against None explicitly: hash slot 0 is a
                        # valid slot and must be routed as a keyed command.
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # One task per node; each returns a truthy value if its batch had errors.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            # Keep the failure as this command's result.
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        """Always raises: MULTI is only available in transactional context."""
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        """Always raises: WATCH is only available in transactional context."""
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        """Always raises: UNWATCH is only available in transactional context."""
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        """Always raises: DISCARD is only available in transactional context."""
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """
        Queue an UNLINK for exactly one key.

        :raises RedisClusterException: unless exactly one name is given
        :returns: the owning pipeline
        """
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])

2221 

2222 

class TransactionStrategy(AbstractStrategy):
    """
    Transactional strategy: queued commands are sent wrapped in MULTI/EXEC over
    one connection of the node that owns the transaction's hash slot.
    """

    # Commands that carry no key and therefore no slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands that are sent right away instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which the connection is no longer watching keys.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        self._pipeline_slots: Set[int] = set()
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Private copy so the extra supported errors don't leak into the client.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not been
        altered as long as the watch commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to the
        node that owns that slot and reuse it.

        :raises RedisClusterException: when no slot has been recorded yet
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            next(iter(self._pipeline_slots)), False
        )
        self._transaction_node = node

        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """
        Synchronous entry point for queuing or immediately running a command.

        ``_execute_command`` is a coroutine, so it is driven to completion with
        ``asyncio.run`` on a helper thread which is joined before returning.

        :returns: whatever ``_execute_command`` returned. That is the owning
            pipeline when the command was queued, or a not-yet-awaited coroutine
            when the command has to be executed immediately.
        :raises Exception: any exception raised inside the helper thread is
            re-raised here
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error is not None:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Record the command's slot, then either queue it or prepare its immediate
        execution.

        A command is executed immediately when keys are being watched or when
        it is one of ``IMMEDIATE_EXECUTE_COMMANDS``, provided ``multi()`` has
        not been called. In that case the *un-awaited* coroutine of
        ``_immediate_execute_command`` is returned for the caller to await.

        :raises CrossSlotTransactionError: immediate command on a slot other
            than the one(s) already recorded
        :raises RedisClusterException: immediate command without a slot
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # Deliberately not awaited here; see the docstring.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        """Mark the strategy as watching; WATCH is rejected after MULTI."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        """Send one command right away, under this strategy's retry policy."""
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        """Send a command over the transaction connection and return its reply."""
        redis_node, connection = self._get_client_and_connection_for_transaction()
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        """
        Failure callback for the retry helper.

        :raises WatchError: on a slot redirect while executing with watched keys
        """
        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            # Forget the connection so the next attempt acquires a fresh one.
            if self._transaction_connection:
                self._transaction_connection = None

            client = self._pipe.cluster_client
            client.reinitialize_counter += 1
            if (
                client.reinitialize_steps
                and client.reinitialize_counter % client.reinitialize_steps == 0
            ):
                await client.nodes_manager.initialize()
                # Reset the counter that was incremented above: it lives on the
                # client, not on this strategy.
                client.reinitialize_counter = 0
            else:
                # NOTE(review): only AskError reaches update_moved_exception
                # here; confirm whether MovedError was intended.
                if isinstance(error, AskError):
                    client.nodes_manager.update_moved_exception(error)

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Not available in transactional context; always raises."""
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the queued commands as one transaction.

        Returns ``[]`` without contacting the server when nothing is queued and
        there is either no WATCH in effect or no slot recorded.
        ``allow_redirections`` is accepted for interface compatibility and is
        not used.
        """
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Run ``_execute_transaction`` under this strategy's retry policy."""
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send MULTI, the queued commands and EXEC in one packed write, then read
        back every reply.

        :raises CrossSlotTransactionError: commands span more than one slot
        :raises WatchError: EXEC returned ``None``
        :raises InvalidPipelineStack: EXEC reply length differs from the queue
        :returns: the EXEC replies after applying the client's response callbacks
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)
        # Uniform list of (position, value) pairs; positions are used further
        # down to splice the values into the EXEC reply.
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, ("MULTI",))
            errors.append((0, e))
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, ("MULTI",))
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append((i, slot_error))
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append((i, e))

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # Prefer the first error collected while queuing; skip entries
            # that are EMPTY_RESPONSE placeholders rather than exceptions.
            first_error = next(
                (e for _, e in errors if isinstance(e, BaseException)), None
            )
            if first_error is not None:
                raise first_error
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)
        return data

    async def reset(self):
        """
        Empty the queue, release the transaction connection (sending UNWATCH
        first when keys are watched) and clear all transaction state.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection:
                    await self._transaction_connection.disconnect()

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        """
        Switch to an explicit transaction.

        :raises RedisError: on a nested call, or when commands were already
            queued
        """
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        """
        Send WATCH for ``names`` and return the reply.

        :raises RedisError: when called after ``multi()``
        """
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Send UNWATCH if keys are being watched; otherwise return ``True``."""
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        """Drop the pending transaction by resetting the strategy."""
        await self.reset()

    async def unlink(self, *names):
        """
        Queue an UNLINK of ``names``, or run it immediately while watching.

        :returns: the owning pipeline when the command was queued, otherwise
            the reply of the immediately executed command
        """
        result = self.execute_command("UNLINK", *names)
        # While WATCHing, execute_command returns the coroutine of the
        # immediate execution; it must be awaited or the command is never
        # sent. Only coroutines are awaited: the pipeline object is awaitable
        # too, but awaiting it re-initializes it and empties the queue.
        if asyncio.iscoroutine(result):
            return await result
        return result