import asyncio
import collections
import random
import socket
import threading
import time
import warnings
from abc import ABC, abstractmethod
from copy import copy
from itertools import chain
from typing import (
    Any,
    Callable,
    Coroutine,
    Deque,
    Dict,
    Generator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from redis._parsers import AsyncCommandsParser, Encoder
from redis._parsers.helpers import (
    _RedisCallbacks,
    _RedisCallbacksRESP2,
    _RedisCallbacksRESP3,
)
from redis.asyncio.client import ResponseCallbackT
from redis.asyncio.connection import Connection, SSLConnection, parse_url
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.auth.token import TokenInterface
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
    PIPELINE_BLOCKED_COMMANDS,
    PRIMARY,
    REPLICA,
    SLOT_ID,
    AbstractRedisCluster,
    LoadBalancer,
    LoadBalancingStrategy,
    block_pipeline_command,
    get_node_name,
    parse_cluster_slots,
)
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.credentials import CredentialProvider
from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
from redis.exceptions import (
    AskError,
    BusyLoadingError,
    ClusterDownError,
    ClusterError,
    ConnectionError,
    CrossSlotTransactionError,
    DataError,
    ExecAbortError,
    InvalidPipelineStack,
    MaxConnectionsError,
    MovedError,
    RedisClusterException,
    RedisError,
    ResponseError,
    SlotNotCoveredError,
    TimeoutError,
    TryAgainError,
    WatchError,
)
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import (
    SSL_AVAILABLE,
    deprecated_args,
    deprecated_function,
    get_lib_version,
    safe_str,
    str_if_bytes,
    truncate_text,
)

if SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None

TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)


class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new RedisCluster client.

    Pass one of the following:

      - `host` & `port`
      - `startup_nodes`

    | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
    | Use ``await`` :meth:`close` to disconnect connections & close client.

    Many commands support the target_nodes kwarg. It can be one of the
    :attr:`NODE_FLAGS`:

      - :attr:`PRIMARIES`
      - :attr:`REPLICAS`
      - :attr:`ALL_NODES`
      - :attr:`RANDOM`
      - :attr:`DEFAULT_NODE`

    Note: This client is not thread/process/fork safe.
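
    For example, a minimal sketch of targeting specific nodes (host and port
    are illustrative)::

        rc = RedisCluster(startup_nodes=[ClusterNode("localhost", 6379)])
        await rc.initialize()
        await rc.ping(target_nodes=RedisCluster.PRIMARIES)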

    :param host:
        | Can be used to point to a startup node
    :param port:
        | Port used if **host** is provided
    :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
    :param require_full_coverage:
        | When set to ``False``: the client will not require a full coverage of
          the slots. However, if not all slots are covered, and at least one node
          has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
          a :class:`~.ClusterDownError` for some key-based commands.
        | When set to ``True``: all slots must be covered to construct the cluster
          client. If not all slots are covered, :class:`~.RedisClusterException`
          will be thrown.
        | See:
          https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
    :param read_from_replicas:
        | @deprecated - please use load_balancing_strategy instead
        | Enable read from replicas in READONLY mode.
          When set to true, read commands will be distributed between the primary
          and its replicas in a round-robin manner.
          The data read from replicas is eventually consistent with the data in
          primary nodes.
    :param load_balancing_strategy:
        | Enable read from replicas in READONLY mode and defines the load balancing
          strategy that will be used for cluster node selection.
          The data read from replicas is eventually consistent with the data in
          primary nodes.
    :param dynamic_startup_nodes:
        | Set the RedisCluster's startup nodes to all the discovered nodes.
          If true (the default value), the cluster's discovered nodes will be used to
          determine the cluster nodes-slots mapping in the next topology refresh.
          It will remove the initially passed startup nodes if their endpoints aren't
          listed in the CLUSTER SLOTS output.
          If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
          specific IP addresses, it is best to set it to false.
    :param reinitialize_steps:
        | Specifies the number of MOVED errors that need to occur before
          reinitializing the whole cluster topology. If a MOVED error occurs and
          the cluster does not need to be reinitialized on this current error
          handling, only the MOVED slot will be patched with the redirected node.
          To reinitialize the cluster on every MOVED error, set reinitialize_steps
          to 1. To avoid reinitializing the cluster on moved errors, set
          reinitialize_steps to 0.
    :param cluster_error_retry_attempts:
        | @deprecated - Please configure the 'retry' object instead.
          If a 'retry' object is set, this argument is ignored!

          Number of times to retry before raising an error when
          :class:`~.TimeoutError`, :class:`~.ConnectionError`,
          :class:`~.SlotNotCoveredError` or :class:`~.ClusterDownError`
          are encountered.
    :param retry:
        | A retry object that defines the retry strategy and the number of
          retries for the cluster client.
          In the current implementation for the cluster client (starting from
          redis-py version 6.0.0), the retry object is not yet fully utilized;
          it is used only to determine the number of retries for the cluster client.
          In future releases the retry object will be used to handle the cluster
          client retries!
    :param max_connections:
        | Maximum number of connections per node. If there are no free connections
          & the maximum number of connections are already created, a
          :class:`~.MaxConnectionsError` is raised.
    :param address_remap:
        | An optional callable which, when provided with an internal network
          address of a node, e.g. a `(host, port)` tuple, will return the address
          where the node is reachable. This can be used to map the addresses at
          which the nodes _think_ they are, to addresses at which a client may
          reach them, such as when they sit behind a proxy.

    | Rest of the arguments will be passed to the
      :class:`~redis.asyncio.connection.Connection` instances when created

    :raises RedisClusterException:
        if any arguments are invalid or unknown. E.g.:

        - `db` != 0 or None
        - `path` argument for unix socket connection
        - none of the `host`/`port` & `startup_nodes` were provided

    """

    @classmethod
    def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
        """
        Return a Redis client object configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0

        Two URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates an SSL-wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>

        The username, password, hostname, path and all querystring values are passed
        through ``urllib.parse.unquote`` in order to replace any percent-encoded
        values with their corresponding characters.

        All querystring options are cast to their appropriate Python types. Boolean
        arguments can be specified with string values "True"/"False" or "Yes"/"No".
        Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
        parsed, the querystring arguments and keyword arguments are passed to
        :class:`~redis.asyncio.connection.Connection` when created.
        In the case of conflicting arguments, querystring arguments are used.
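
        A sketch with querystring options (the option names follow
        :class:`~redis.asyncio.connection.Connection` keyword arguments)::

            rc = RedisCluster.from_url(
                "redis://localhost:6379/0?decode_responses=True&socket_timeout=5"
            )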
        """
        kwargs.update(parse_url(url))
        if kwargs.pop("connection_class", None) is SSLConnection:
            kwargs["ssl"] = True
        return cls(**kwargs)

    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )

    @deprecated_args(
        args_to_warn=["read_from_replicas"],
        reason="Please configure the 'load_balancing_strategy' instead",
        version="5.3.0",
    )
    @deprecated_args(
        args_to_warn=[
            "cluster_error_retry_attempts",
        ],
        reason="Please configure the 'retry' object instead",
        version="6.0.0",
    )
    def __init__(
        self,
        host: Optional[str] = None,
        port: Union[str, int] = 6379,
        # Cluster related kwargs
        startup_nodes: Optional[List["ClusterNode"]] = None,
        require_full_coverage: bool = True,
        read_from_replicas: bool = False,
        load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
        dynamic_startup_nodes: bool = True,
        reinitialize_steps: int = 5,
        cluster_error_retry_attempts: int = 3,
        max_connections: int = 2**31,
        retry: Optional["Retry"] = None,
        retry_on_error: Optional[List[Type[Exception]]] = None,
        # Client related kwargs
        db: Union[str, int] = 0,
        path: Optional[str] = None,
        credential_provider: Optional[CredentialProvider] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        # Encoding related kwargs
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        # Connection related kwargs
        health_check_interval: float = 0,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_timeout: Optional[float] = None,
        # SSL related kwargs
        ssl: bool = False,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_certfile: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_keyfile: Optional[str] = None,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        protocol: Optional[int] = 2,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        if db:
            raise RedisClusterException(
                "Argument 'db' must be 0 or None in cluster mode"
            )

        if path:
            raise RedisClusterException(
                "Unix domain socket is not supported in cluster mode"
            )

        if (not host or not port) and not startup_nodes:
            raise RedisClusterException(
                "RedisCluster requires at least one node to discover the cluster.\n"
                "Please provide one of the following or use RedisCluster.from_url:\n"
                ' - host and port: RedisCluster(host="localhost", port=6379)\n'
                " - startup_nodes: RedisCluster(startup_nodes=["
                'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
            )

        kwargs: Dict[str, Any] = {
            "max_connections": max_connections,
            "connection_class": Connection,
            # Client related kwargs
            "credential_provider": credential_provider,
            "username": username,
            "password": password,
            "client_name": client_name,
            "lib_name": lib_name,
            "lib_version": lib_version,
            # Encoding related kwargs
            "encoding": encoding,
            "encoding_errors": encoding_errors,
            "decode_responses": decode_responses,
            # Connection related kwargs
            "health_check_interval": health_check_interval,
            "socket_connect_timeout": socket_connect_timeout,
            "socket_keepalive": socket_keepalive,
            "socket_keepalive_options": socket_keepalive_options,
            "socket_timeout": socket_timeout,
            "protocol": protocol,
        }

        if ssl:
            # SSL related kwargs
            kwargs.update(
                {
                    "connection_class": SSLConnection,
                    "ssl_ca_certs": ssl_ca_certs,
                    "ssl_ca_data": ssl_ca_data,
                    "ssl_cert_reqs": ssl_cert_reqs,
                    "ssl_certfile": ssl_certfile,
                    "ssl_check_hostname": ssl_check_hostname,
                    "ssl_keyfile": ssl_keyfile,
                    "ssl_min_version": ssl_min_version,
                    "ssl_ciphers": ssl_ciphers,
                }
            )

        if read_from_replicas or load_balancing_strategy:
            # Call our on_connect function to configure READONLY mode
            kwargs["redis_connect_func"] = self.on_connect

        if retry:
            self.retry = retry
        else:
            self.retry = Retry(
                backoff=ExponentialWithJitterBackoff(base=1, cap=10),
                retries=cluster_error_retry_attempts,
            )
        if retry_on_error:
            self.retry.update_supported_errors(retry_on_error)

        kwargs["response_callbacks"] = _RedisCallbacks.copy()
        if kwargs.get("protocol") in ["3", 3]:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
        else:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
        self.connection_kwargs = kwargs

        if startup_nodes:
            passed_nodes = []
            for node in startup_nodes:
                passed_nodes.append(
                    ClusterNode(node.host, node.port, **self.connection_kwargs)
                )
            startup_nodes = passed_nodes
        else:
            startup_nodes = []
        if host and port:
            startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

        self.nodes_manager = NodesManager(
            startup_nodes,
            require_full_coverage,
            kwargs,
            dynamic_startup_nodes=dynamic_startup_nodes,
            address_remap=address_remap,
            event_dispatcher=self._event_dispatcher,
        )
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.read_from_replicas = read_from_replicas
        self.load_balancing_strategy = load_balancing_strategy
        self.reinitialize_steps = reinitialize_steps
        self.reinitialize_counter = 0
        self.commands_parser = AsyncCommandsParser()
        self.node_flags = self.__class__.NODE_FLAGS.copy()
        self.command_flags = self.__class__.COMMAND_FLAGS.copy()
        self.response_callbacks = kwargs["response_callbacks"]
        self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
        self.result_callbacks["CLUSTER SLOTS"] = (
            lambda cmd, res, **kwargs: parse_cluster_slots(
                list(res.values())[0], **kwargs
            )
        )

        self._initialize = True
        self._lock: Optional[asyncio.Lock] = None

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    async def initialize(self) -> "RedisCluster":
        """Get all nodes from startup nodes & create connections if not initialized."""
        if self._initialize:
            if not self._lock:
                self._lock = asyncio.Lock()
            async with self._lock:
                if self._initialize:
                    try:
                        await self.nodes_manager.initialize()
                        await self.commands_parser.initialize(
                            self.nodes_manager.default_node
                        )
                        self._initialize = False
                    except BaseException:
                        await self.nodes_manager.aclose()
                        await self.nodes_manager.aclose("startup_nodes")
                        raise
        return self

    async def aclose(self) -> None:
        """Close all connections & client if initialized."""
        if not self._initialize:
            if not self._lock:
                self._lock = asyncio.Lock()
            async with self._lock:
                if not self._initialize:
                    self._initialize = True
                    await self.nodes_manager.aclose()
                    await self.nodes_manager.aclose("startup_nodes")

    @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility."""
        await self.aclose()

    async def __aenter__(self) -> "RedisCluster":
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    def __await__(self) -> Generator[Any, None, "RedisCluster"]:
        return self.initialize().__await__()

    _DEL_MESSAGE = "Unclosed RedisCluster client"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "_initialize") and not self._initialize:
            _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass

    async def on_connect(self, connection: Connection) -> None:
        await connection.on_connect()

        # Send the READONLY command to the server to configure the connection
        # as readonly. Since each cluster node may change its server type due
        # to a failover, we should establish a READONLY connection
        # regardless of the server type. If this is a primary connection,
        # READONLY would not affect executing write commands.
        await connection.send_command("READONLY")
        if str_if_bytes(await connection.read_response()) != "OK":
            raise ConnectionError("READONLY command failed")

    def get_nodes(self) -> List["ClusterNode"]:
        """Get all nodes of the cluster."""
        return list(self.nodes_manager.nodes_cache.values())

    def get_primaries(self) -> List["ClusterNode"]:
        """Get the primary nodes of the cluster."""
        return self.nodes_manager.get_nodes_by_server_type(PRIMARY)

    def get_replicas(self) -> List["ClusterNode"]:
        """Get the replica nodes of the cluster."""
        return self.nodes_manager.get_nodes_by_server_type(REPLICA)

    def get_random_node(self) -> "ClusterNode":
        """Get a random node of the cluster."""
        return random.choice(list(self.nodes_manager.nodes_cache.values()))

    def get_default_node(self) -> "ClusterNode":
        """Get the default node of the client."""
        return self.nodes_manager.default_node

    def set_default_node(self, node: "ClusterNode") -> None:
        """
        Set the default node of the client.

        :raises DataError: if None is passed or the node does not exist in the
            cluster.
        """
        if not node or not self.get_node(node_name=node.name):
            raise DataError("The requested node does not exist in the cluster.")

        self.nodes_manager.default_node = node

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """Get a node by (host, port) or node_name."""
        return self.nodes_manager.get_node(host, port, node_name)

    def get_node_from_key(
        self, key: str, replica: bool = False
    ) -> Optional["ClusterNode"]:
        """
        Get the cluster node corresponding to the provided key.

        :param key:
        :param replica:
            | Indicates if a replica should be returned
            |
              None will be returned if no replica holds this key

        :raises SlotNotCoveredError: if the key is not covered by any slot.
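
        For example, assuming an initialized client ``rc``::

            node = rc.get_node_from_key("user:1000")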
        """
        slot = self.keyslot(key)
        slot_cache = self.nodes_manager.slots_cache.get(slot)
        if not slot_cache:
            raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')

        if replica:
            if len(self.nodes_manager.slots_cache[slot]) < 2:
                return None
            node_idx = 1
        else:
            node_idx = 0

        return slot_cache[node_idx]

    def keyslot(self, key: EncodableT) -> int:
        """
        Find the keyslot for a given key.

        See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
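
        For example (slot values are stable across clusters, since they are
        derived from a CRC16 of the key)::

            rc.keyslot("foo")  # == 12182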
        """
        return key_slot(self.encoder.encode(key))

    def get_encoder(self) -> Encoder:
        """Get the encoder object of the client."""
        return self.encoder

    def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
        """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
        return self.connection_kwargs

    def set_retry(self, retry: Retry) -> None:
        self.retry = retry

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback."""
        self.response_callbacks[command] = callback

    async def _determine_nodes(
        self, command: str, *args: Any, node_flag: Optional[str] = None
    ) -> List["ClusterNode"]:
        # Determine which nodes the command should be executed on.
        # Returns a list of target nodes.
        if not node_flag:
            # get the nodes group for this command if it was predefined
            node_flag = self.command_flags.get(command)

        if node_flag in self.node_flags:
            if node_flag == self.__class__.DEFAULT_NODE:
                # return the cluster's default node
                return [self.nodes_manager.default_node]
            if node_flag == self.__class__.PRIMARIES:
                # return all primaries
                return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
            if node_flag == self.__class__.REPLICAS:
                # return all replicas
                return self.nodes_manager.get_nodes_by_server_type(REPLICA)
            if node_flag == self.__class__.ALL_NODES:
                # return all nodes
                return list(self.nodes_manager.nodes_cache.values())
            if node_flag == self.__class__.RANDOM:
                # return a random node
                return [random.choice(list(self.nodes_manager.nodes_cache.values()))]

        # get the node that holds the key's slot
        return [
            self.nodes_manager.get_node_from_slot(
                await self._determine_slot(command, *args),
                self.read_from_replicas and command in READ_COMMANDS,
                self.load_balancing_strategy if command in READ_COMMANDS else None,
            )
        ]

    async def _determine_slot(self, command: str, *args: Any) -> int:
        if self.command_flags.get(command) == SLOT_ID:
            # The command contains the slot ID
            return int(args[0])

        # Get the keys in the command

        # EVAL and EVALSHA are common enough that it's wasteful to go to the
        # redis server to parse the keys. Besides, there is a bug in redis<7.0
        # where `self._get_command_keys()` fails anyway. So, we special case
        # EVAL/EVALSHA.
        # - issue: https://github.com/redis/redis/issues/9493
        # - fix: https://github.com/redis/redis/pull/9733
        if command.upper() in ("EVAL", "EVALSHA"):
            # command syntax: EVAL "script body" num_keys ...
            if len(args) < 2:
                raise RedisClusterException(
                    f"Invalid args in command: {command, *args}"
                )
            keys = args[2 : 2 + int(args[1])]
            # if there are 0 keys, that means the script can be run on any node,
            # so we can just return a random slot
            if not keys:
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
        else:
            keys = await self.commands_parser.get_keys(command, *args)
            if not keys:
                # FCALL can call a function with 0 keys, which means the function
                # can be run on any node, so we can just return a random slot
                if command.upper() in ("FCALL", "FCALL_RO"):
                    return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
                raise RedisClusterException(
                    "No way to dispatch this command to Redis Cluster. "
                    "Missing key.\nYou can execute the command by specifying "
                    f"target nodes.\nCommand: {args}"
                )

        # single key command
        if len(keys) == 1:
            return self.keyslot(keys[0])

        # multi-key command; we need to make sure all keys are mapped to
        # the same slot
        slots = {self.keyslot(key) for key in keys}
        if len(slots) != 1:
            raise RedisClusterException(
                f"{command} - all keys must map to the same key slot"
            )

        return slots.pop()

    def _is_node_flag(self, target_nodes: Any) -> bool:
        return isinstance(target_nodes, str) and target_nodes in self.node_flags

    def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
        if isinstance(target_nodes, list):
            nodes = target_nodes
        elif isinstance(target_nodes, ClusterNode):
            # Supports passing a single ClusterNode as a variable
            nodes = [target_nodes]
        elif isinstance(target_nodes, dict):
            # Supports dictionaries of the format {node_name: node}.
            # It enables executing commands on multiple nodes as follows:
            # rc.cluster_save_config(rc.get_primaries())
            nodes = list(target_nodes.values())
        else:
            raise TypeError(
                "target_nodes type can be one of the following: "
                "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES), "
                "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
                f"The passed type is {type(target_nodes)}"
            )
        return nodes

    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` object & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
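
        For example, assuming an initialized client ``rc``::

            await rc.execute_command("SET", "foo", "bar")
            await rc.execute_command("PING", target_nodes=RedisCluster.ALL_NODES)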
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        for _ in range(execute_attempts):
            if self._initialize:
                await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args, node_flag=passed_targets
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return ret
                else:
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return dict(zip(keys, values))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache were reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    continue
                else:
                    # raise the exception
                    raise e

    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                return await target_node.execute_command(*args, **kwargs)
            except BusyLoadingError:
                raise
            except MaxConnectionsError:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                raise
            except (ConnectionError, TimeoutError):
                # Connection retries are being handled in the node's
                # Retry object.
                # Remove the failed node from the startup nodes before we try
                # to reinitialize the cluster
                self.nodes_manager.startup_nodes.pop(target_node.name, None)
                # Hard force a reinitialize of the node/slots setup
                # and try again with the new setup
                await self.aclose()
                raise
            except (ClusterDownError, SlotNotCoveredError):
                # ClusterDownError can occur during a failover; to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized, or it can be a temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # The 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    self.nodes_manager._moved_exception = e
                    moved = True
            except AskError as e:
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
            except TryAgainError:
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)

        raise ClusterError("TTL exhausted.")

    def pipeline(
        self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
    ) -> "ClusterPipeline":
        """
        Create & return a new :class:`~.ClusterPipeline` object.

        The cluster implementation of pipeline does not support shard_hint.

        :raises RedisClusterException: if shard_hint is a truthy value
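
        For example::

            async with rc.pipeline() as pipe:
                results = await pipe.set("A", 1).get("A").execute()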
        """
        if shard_hint:
            raise RedisClusterException("shard_hint is deprecated in cluster mode")

        return ClusterPipeline(self, transaction)

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). If the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.
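
        For example, a minimal sketch of acquiring the lock (the key name is
        illustrative)::

            async with rc.lock("my-lock", timeout=5):
                ...  # critical section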

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    async def transaction(
        self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.
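
        A minimal sketch (key names are illustrative)::

            async def incr_balance(pipe):
                await pipe.get("balance")  # runs immediately while watching
                pipe.multi()
                pipe.incr("balance")  # queued until execute

            await rc.transaction(incr_balance, "balance")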
        """
        shard_hint = kwargs.pop("shard_hint", None)
        value_from_callable = kwargs.pop("value_from_callable", False)
        watch_delay = kwargs.pop("watch_delay", None)
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = await func(pipe)
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        # Use asyncio.sleep so the event loop is not blocked
                        # while waiting to retry the transaction.
                        await asyncio.sleep(watch_delay)
                    continue


class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).
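
    A minimal sketch of direct use (host and port are illustrative; nodes are
    normally created and managed by :class:`NodesManager`)::

        node = ClusterNode("localhost", 6379)
        await node.execute_command("PING")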
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        return isinstance(obj, ClusterNode) and obj.name == self.name

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    pass
                break

    async def disconnect(self) -> None:
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError, which will trigger a
                # disconnection of the socket.
                # The retries will be handled on the cluster client level,
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        """
        self._free.append(connection)

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove the keys entry; it is only needed for caching.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        # Acquire connection
        connection = self.acquire_connection()

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        # Acquire connection
        connection = self.acquire_connection()

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                cmd.result = e
                ret = True

        # Release connection
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy function; needs to be passed as the error callback to the retry
        object.
        :param error:
        :return:
        """
        pass


class NodesManager:
    __slots__ = (
        "_dynamic_startup_nodes",
        "_moved_exception",
        "_event_dispatcher",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "read_load_balancer",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )

    def __init__(
        self,
        startup_nodes: List["ClusterNode"],
        require_full_coverage: bool,
        connection_kwargs: Dict[str, Any],
        dynamic_startup_nodes: bool = True,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        self.startup_nodes = {node.name: node for node in startup_nodes}
        self.require_full_coverage = require_full_coverage
        self.connection_kwargs = connection_kwargs
        self.address_remap = address_remap

        self.default_node: "ClusterNode" = None
        self.nodes_cache: Dict[str, "ClusterNode"] = {}
        self.slots_cache: Dict[int, List["ClusterNode"]] = {}
        self.read_load_balancer = LoadBalancer()

        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
        self._moved_exception: MovedError = None
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        if host and port:
            # the user passed host and port
            if host == "localhost":
                host = socket.gethostbyname(host)
            return self.nodes_cache.get(get_node_name(host=host, port=port))
        elif node_name:
            return self.nodes_cache.get(node_name)
        else:
            raise DataError(
                "get_node requires one of the following: 1. node name 2. host and port"
            )

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    task = asyncio.create_task(old.pop(name).disconnect())  # noqa

        for name, node in new.items():
            if name in old:
                if old[name] is node:
                    continue
                task = asyncio.create_task(old[name].disconnect())  # noqa
            old[name] = node

    def update_moved_exception(self, exception):
        self._moved_exception = exception

    def _update_moved_slots(self) -> None:
        e = self._moved_exception
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        if redirected_node in self.slots_cache[e.slot_id]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = self.slots_cache[e.slot_id][0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            self.slots_cache[e.slot_id].append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            self.slots_cache[e.slot_id].remove(redirected_node)
            # Override the old primary with the new one
            self.slots_cache[e.slot_id][0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        else:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replicas) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        # Reset moved_exception
        self._moved_exception = None

    def get_node_from_slot(
        self,
        slot: int,
        read_from_replicas: bool = False,
        load_balancing_strategy=None,
    ) -> "ClusterNode":
        if self._moved_exception:
            self._update_moved_slots()

        if read_from_replicas is True and load_balancing_strategy is None:
            load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN

        try:
            if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
                # get the server index using the strategy defined in
                # load_balancing_strategy
                primary_name = self.slots_cache[slot][0].name
                node_idx = self.read_load_balancer.get_server_index(
                    primary_name, len(self.slots_cache[slot]), load_balancing_strategy
                )
                return self.slots_cache[slot][node_idx]
            return self.slots_cache[slot][0]
        except (IndexError, TypeError):
            raise SlotNotCoveredError(
                f'Slot "{slot}" not covered by the cluster. '
                f'"require_full_coverage={self.require_full_coverage}"'
            )

    def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
        return [
            node
            for node in self.nodes_cache.values()
            if node.server_type == server_type
        ]

    async def initialize(self) -> None:
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Convert to tuple to prevent RuntimeError if self.startup_nodes
        # is modified during iteration
        for startup_node in tuple(self.startup_nodes.values()):
            try:
                # Make sure cluster mode is enabled on this node
                try:
                    self._event_dispatcher.dispatch(
                        AfterAsyncClusterInstantiationEvent(
                            self.nodes_cache,
                            self.connection_kwargs.get("credential_provider", None),
                        )
                    )
                    cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
                except ResponseError:
                    raise RedisClusterException(
                        "Cluster mode is not enabled on this node"
                    )
                startup_nodes_reachable = True
            except Exception as e:
                # Try the next startup node.
                # The exception is saved and raised only if we have no more nodes.
                exception = e
                continue

            # The CLUSTER SLOTS command results in the following output:
            # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
            # where each node contains the following list: [IP, port, node_id]
            # Therefore, cluster_slots[0][2][0] will be the IP address of the
            # primary node of the first slot section.
            # If there's only one server in the cluster, its ``host`` is ''.
            # Fix it to the host in startup_nodes
            if (
                len(cluster_slots) == 1
                and not cluster_slots[0][2][0]
                and len(self.startup_nodes) == 1
            ):
                cluster_slots[0][2][0] = startup_node.host

            for slot in cluster_slots:
                for i in range(2, len(slot)):
                    slot[i] = [str_if_bytes(val) for val in slot[i]]
                primary_node = slot[2]
                host = primary_node[0]
                if host == "":
                    host = startup_node.host
                port = int(primary_node[1])
                host, port = self.remap_host_port(host, port)

                nodes_for_slot = []

                target_node = tmp_nodes_cache.get(get_node_name(host, port))
                if not target_node:
                    target_node = ClusterNode(
                        host, port, PRIMARY, **self.connection_kwargs
                    )
                # add this node to the nodes cache
                tmp_nodes_cache[target_node.name] = target_node
                nodes_for_slot.append(target_node)

                replica_nodes = slot[3:]
                for replica_node in replica_nodes:
                    host = replica_node[0]
                    port = replica_node[1]
                    host, port = self.remap_host_port(host, port)

                    target_replica_node = tmp_nodes_cache.get(
                        get_node_name(host, port)
                    )
                    if not target_replica_node:
                        target_replica_node = ClusterNode(
                            host, port, REPLICA, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_replica_node.name] = target_replica_node
                    nodes_for_slot.append(target_replica_node)

                for i in range(int(slot[0]), int(slot[1]) + 1):
                    if i not in tmp_slots:
                        tmp_slots[i] = nodes_for_slot
                    else:
                        # Validate that 2 nodes want to use the same slot cache
                        # setup
                        tmp_slot = tmp_slots[i][0]
                        if tmp_slot.name != target_node.name:
                            disagreements.append(
                                f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                            )

                            if len(disagreements) > 5:
                                raise RedisClusterException(
                                    f"startup_nodes could not agree on a valid "
                                    f"slots cache: {', '.join(disagreements)}"
                                )

            # Validate if all slots are covered or if we should try next startup node
            fully_covered = True
            for i in range(REDIS_CLUSTER_HASH_SLOTS):
                if i not in tmp_slots:
                    fully_covered = False
                    break
            if fully_covered:
                break

        if not startup_nodes_reachable:
            raise RedisClusterException(
                f"Redis Cluster cannot be connected. Please provide at least "
                f"one reachable node: {str(exception)}"
            ) from exception

        # Check if the slots are not fully covered
        if not fully_covered and self.require_full_coverage:
            # Despite the requirement that the slots be covered, there
            # isn't full coverage
            raise RedisClusterException(
                f"Not all slots are covered after querying all startup nodes. "
                f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                f"covered..."
            )

        # Set the tmp variables to the real variables
        self.slots_cache = tmp_slots
        self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

        if self._dynamic_startup_nodes:
            # Populate the startup nodes with all discovered nodes
            self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

        # Set the default node
        self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
        # If initialize was called after a MovedError, clear it
        self._moved_exception = None

    async def aclose(self, attr: str = "nodes_cache") -> None:
        self.default_node = None
        await asyncio.gather(
            *(
                asyncio.create_task(node.disconnect())
                for node in getattr(self, attr).values()
            )
        )

    def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
        """
        Remap the host and port returned from the cluster to a different
        internal value. Useful if the client is not connecting directly
        to the cluster.
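
        A sketch of an ``address_remap`` callable (the port offset is
        illustrative, e.g. for nodes behind a proxy or port forwarding)::

            def remap(address):
                host, port = address
                return ("127.0.0.1", port + 10000)

            rc = RedisCluster("localhost", 16379, address_remap=remap)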
        """
        if self.address_remap:
            return self.address_remap((host, port))
        return host, port


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the
    array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    """
1606
1607 __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1608
1609 def __init__(
1610 self, client: RedisCluster, transaction: Optional[bool] = None
1611 ) -> None:
1612 self.cluster_client = client
1613 self._transaction = transaction
1614 self._execution_strategy: ExecutionStrategy = (
1615 PipelineStrategy(self)
1616 if not self._transaction
1617 else TransactionStrategy(self)
1618 )
1619
1620 async def initialize(self) -> "ClusterPipeline":
1621 await self._execution_strategy.initialize()
1622 return self
1623
1624 async def __aenter__(self) -> "ClusterPipeline":
1625 return await self.initialize()
1626
1627 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1628 await self.reset()
1629
1630 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1631 return self.initialize().__await__()
1632
1633 def __bool__(self) -> bool:
1634 "Pipeline instances should always evaluate to True on Python 3+"
1635 return True
1636
1637 def __len__(self) -> int:
1638 return len(self._execution_strategy)
1639
1640 def execute_command(
1641 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1642 ) -> "ClusterPipeline":
1643 """
1644 Append a raw command to the pipeline.
1645
1646 :param args:
1647 | Raw command args
1648 :param kwargs:
1649
1650 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
1651 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1652 - Rest of the kwargs are passed to the Redis connection
1653 """
1654 return self._execution_strategy.execute_command(*args, **kwargs)
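
    # Illustrative sketch (not part of the library source): routing a queued
    # command explicitly with ``target_nodes``. The flag shown is one of the
    # NODE_FLAGS documented on RedisCluster:
    #
    #     pipe = rc.pipeline()
    #     pipe.execute_command("PING", target_nodes=RedisCluster.RANDOM)
    #     pipe.execute_command("GET", "A")   # routed by the slot of "A"
    #     results = await pipe.execute()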

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as many times as specified by :attr:`retry`
        & then raise an exception.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of
              redirection errors

        :raises RedisClusterException: if target_nodes is not provided & the
            command can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
            self.execute_command(command, *slot_keys)

        return self
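
    # Illustrative sketch (hypothetical keys): multi-key commands such as
    # ``pipe.delete("A", "B")`` route through the helper above. If "A" and "B"
    # hash to different slots, the call queues one command per slot, roughly:
    #
    #     pipe.execute_command("DEL", "A")
    #     pipe.execute_command("DEL", "B")
    #
    # which is why such commands can produce several entries in the result
    # array (see the class docstring).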

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """
        Flush all previously queued commands in a transaction and restore the
        connection state to normal.
        """
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``"""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        return self._execution_strategy.mset_nonatomic(mapping)


for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command == "mset_nonatomic":
        continue

    setattr(ClusterPipeline, command, block_pipeline_command(command))
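
# Illustrative sketch (not part of the library source): the loop above replaces
# each blocked command with a stub that raises immediately. Assuming BGSAVE is
# among PIPELINE_BLOCKED_COMMANDS and ``rc`` is a connected RedisCluster:
#
#     pipe = rc.pipeline()
#     pipe.bgsave()  # raises RedisClusterException before anything is queued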


class PipelineCommand:
    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs
        self.position = position
        self.result: Union[Any, Exception] = None

    def __repr__(self) -> str:
        return f"[{self.position}] {self.args} ({self.kwargs})"


class ExecutionStrategy(ABC):
    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the execution strategy.

        See ClusterPipeline.initialize()
        """
        pass

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        See ClusterPipeline.execute_command()
        """
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as many times as specified by :attr:`retry`
        & then raise an exception.

        See ClusterPipeline.execute()
        """
        pass

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Executes multiple MSET commands according to the provided slot/pairs
        mapping.

        See ClusterPipeline.mset_nonatomic()
        """
        pass

    @abstractmethod
    async def reset(self):
        """
        Resets the current execution strategy.

        See: ClusterPipeline.reset()
        """
        pass

    @abstractmethod
    def multi(self):
        """
        Starts a transactional context.

        See: ClusterPipeline.multi()
        """
        pass

    @abstractmethod
    async def watch(self, *names):
        """
        Watch the given keys.

        See: ClusterPipeline.watch()
        """
        pass

    @abstractmethod
    async def unwatch(self):
        """
        Unwatches all previously specified keys.

        See: ClusterPipeline.unwatch()
        """
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink the keys specified by ``names``.

        See: ClusterPipeline.unlink()
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass


class AbstractStrategy(ExecutionStrategy):
    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled
        """
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        return len(self._command_queue)


class PipelineStrategy(AbstractStrategy):
    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        encoder = self._pipe.cluster_client.encoder

        # Group the key/value pairs by hash slot so that each queued MSET
        # only touches keys from a single slot.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe
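
    # Illustrative sketch (hypothetical slot assignment): for a mapping
    # {"A": 2, "B": 3} where "A" and "B" hash to different slots, the loop
    # above queues two single-slot commands:
    #
    #     MSET A 2    (slot of "A")
    #     MSET B 3    (slot of "B")
    #
    # Hence "nonatomic": each per-slot MSET succeeds or fails independently.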

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # All other errors should be raised.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # Map each command to the single node it should run on, grouping
        # commands per node so that each node receives exactly one pipeline.
        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            if passed_targets and not client._is_node_flag(passed_targets):
                target_nodes = client._parse_target_nodes(passed_targets)
            else:
                target_nodes = await client._determine_nodes(
                    *cmd.args, node_flag=passed_targets
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = await client.execute_command(
                                *cmd.args, **cmd.kwargs
                            )
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if
                    # we need to replace it.
                    # Note: when the error is raised we'll reset the default
                    # node in the caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]
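
    # Illustrative sketch (hypothetical nodes): after the grouping above, a
    # queue such as
    #
    #     GET A   -> owned by 127.0.0.1:7001
    #     GET B   -> owned by 127.0.0.1:7002
    #
    # becomes two per-node pipelines executed concurrently with
    # asyncio.gather(); results are then read back from each PipelineCommand
    # in the original queue order.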

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])


class TransactionStrategy(AbstractStrategy):
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        self._pipeline_slots: Set[int] = set()
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For an atomic transaction, watched keys are only guaranteed to be
        unaltered if the WATCH commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to
        the node that owns that key's slot and reuse it for the rest of the
        transaction.
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            list(self._pipeline_slots)[0], False
        )
        self._transaction_node = node

        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection
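
    # Illustrative sketch (not part of the library source): an optimistic
    # locking flow built on this strategy. The hash tag in "{user}:count"
    # keeps every key in the transaction on one slot:
    #
    #     async with rc.pipeline(transaction=True) as pipe:
    #         await pipe.watch("{user}:count")   # WATCH runs immediately
    #         pipe.multi()                       # start queuing commands
    #         pipe.incr("{user}:count")
    #         await pipe.execute()               # WatchError if the key changed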

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        # ClusterPipeline's command methods are synchronous, but the slot and
        # WATCH bookkeeping in _execute_command() is async. asyncio.run()
        # cannot be called from a thread with a running event loop, so the
        # coroutine is driven on a short-lived helper thread instead.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]}, "
                    "it cannot be triggered in a transaction"
                )

            # Note: the coroutine is returned without being awaited here; the
            # synchronous execute_command() wrapper passes it back to the
            # async caller (e.g. watch()/unwatch()), which awaits it on its
            # own event loop.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        redis_node, connection = self._get_client_and_connection_for_transaction()
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            if self._transaction_connection:
                self._transaction_connection = None

            self._pipe.cluster_client.reinitialize_counter += 1
            # Re-fetch the cluster topology every ``reinitialize_steps``
            # redirect/connection errors; otherwise just record the MOVED
            # target for the next lookup.
            if (
                self._pipe.cluster_client.reinitialize_steps
                and self._pipe.cluster_client.reinitialize_counter
                % self._pipe.cluster_client.reinitialize_steps
                == 0
            ):
                await self._pipe.cluster_client.nodes_manager.initialize()
                # Reset the client's counter (not an attribute on this
                # strategy object) so the step cadence starts over.
                self._pipe.cluster_client.reinitialize_counter = 0
            else:
                self._pipe.cluster_client.nodes_manager.update_moved_exception(error)

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC. An ExecAbortError means one of the queued commands
        # was rejected; surface the first annotated error instead.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response. At this point EXEC has not
        # aborted, so ``errors`` should only contain the (index, value) pairs
        # recorded for EMPTY_RESPONSE commands.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)
        return data
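
    # Illustrative sketch of the wire exchange driven above (replies are
    # examples only):
    #
    #     > MULTI               < +OK
    #     > SET {user}:a 1      < +QUEUED
    #     > INCR {user}:b       < +QUEUED
    #     > EXEC                < *2  (one reply per queued command)
    #
    # A null EXEC reply means a WATCHed key changed and surfaces as the
    # WatchError raised above.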

    async def reset(self):
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since
                # we're sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection:
                    await self._transaction_connection.disconnect()

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        await self.reset()

    async def unlink(self, *names):
        return self.execute_command("UNLINK", *names)