1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.helpers import (
30 _RedisCallbacks,
31 _RedisCallbacksRESP2,
32 _RedisCallbacksRESP3,
33)
34from redis.asyncio.client import ResponseCallbackT
35from redis.asyncio.connection import Connection, SSLConnection, parse_url
36from redis.asyncio.lock import Lock
37from redis.asyncio.retry import Retry
38from redis.auth.token import TokenInterface
39from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
40from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
41from redis.cluster import (
42 PIPELINE_BLOCKED_COMMANDS,
43 PRIMARY,
44 REPLICA,
45 SLOT_ID,
46 AbstractRedisCluster,
47 LoadBalancer,
48 LoadBalancingStrategy,
49 block_pipeline_command,
50 get_node_name,
51 parse_cluster_slots,
52)
53from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
54from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
55from redis.credentials import CredentialProvider
56from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
57from redis.exceptions import (
58 AskError,
59 BusyLoadingError,
60 ClusterDownError,
61 ClusterError,
62 ConnectionError,
63 CrossSlotTransactionError,
64 DataError,
65 ExecAbortError,
66 InvalidPipelineStack,
67 MaxConnectionsError,
68 MovedError,
69 RedisClusterException,
70 RedisError,
71 ResponseError,
72 SlotNotCoveredError,
73 TimeoutError,
74 TryAgainError,
75 WatchError,
76)
77from redis.typing import AnyKeyT, EncodableT, KeyT
78from redis.utils import (
79 SSL_AVAILABLE,
80 deprecated_args,
81 deprecated_function,
82 get_lib_version,
83 safe_str,
84 str_if_bytes,
85 truncate_text,
86)
87
88if SSL_AVAILABLE:
89 from ssl import TLSVersion, VerifyMode
90else:
91 TLSVersion = None
92 VerifyMode = None
93
94TargetNodesT = TypeVar(
95 "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
96)
97
98
99class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
100 """
101 Create a new RedisCluster client.
102
103 Pass one of parameters:
104
105 - `host` & `port`
106 - `startup_nodes`
107
108 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
109 | Use ``await`` :meth:`close` to disconnect connections & close client.
110
111 Many commands support the target_nodes kwarg. It can be one of the
112 :attr:`NODE_FLAGS`:
113
114 - :attr:`PRIMARIES`
115 - :attr:`REPLICAS`
116 - :attr:`ALL_NODES`
117 - :attr:`RANDOM`
118 - :attr:`DEFAULT_NODE`
119
120 Note: This client is not thread/process/fork safe.
121
122 :param host:
123 | Can be used to point to a startup node
124 :param port:
125 | Port used if **host** is provided
126 :param startup_nodes:
127 | :class:`~.ClusterNode` to used as a startup node
128 :param require_full_coverage:
129 | When set to ``False``: the client will not require a full coverage of
130 the slots. However, if not all slots are covered, and at least one node
131 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
132 a :class:`~.ClusterDownError` for some key-based commands.
133 | When set to ``True``: all slots must be covered to construct the cluster
134 client. If not all slots are covered, :class:`~.RedisClusterException` will be
135 thrown.
136 | See:
137 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
138 :param read_from_replicas:
139 | @deprecated - please use load_balancing_strategy instead
140 | Enable read from replicas in READONLY mode.
141 When set to true, read commands will be assigned between the primary and
142 its replications in a Round-Robin manner.
143 The data read from replicas is eventually consistent with the data in primary nodes.
144 :param load_balancing_strategy:
145 | Enable read from replicas in READONLY mode and defines the load balancing
146 strategy that will be used for cluster node selection.
147 The data read from replicas is eventually consistent with the data in primary nodes.
148 :param dynamic_startup_nodes:
149 | Set the RedisCluster's startup nodes to all the discovered nodes.
150 If true (default value), the cluster's discovered nodes will be used to
151 determine the cluster nodes-slots mapping in the next topology refresh.
152 It will remove the initial passed startup nodes if their endpoints aren't
153 listed in the CLUSTER SLOTS output.
154 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
155 specific IP addresses, it is best to set it to false.
156 :param reinitialize_steps:
157 | Specifies the number of MOVED errors that need to occur before reinitializing
158 the whole cluster topology. If a MOVED error occurs and the cluster does not
159 need to be reinitialized on this current error handling, only the MOVED slot
160 will be patched with the redirected node.
161 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
162 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
163 0.
164 :param cluster_error_retry_attempts:
165 | @deprecated - Please configure the 'retry' object instead
166 In case 'retry' object is set - this argument is ignored!
167
168 Number of times to retry before raising an error when :class:`~.TimeoutError`,
169 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
170 or :class:`~.ClusterDownError` are encountered
171 :param retry:
172 | A retry object that defines the retry strategy and the number of
173 retries for the cluster client.
174 In current implementation for the cluster client (starting form redis-py version 6.0.0)
175 the retry object is not yet fully utilized, instead it is used just to determine
176 the number of retries for the cluster client.
177 In the future releases the retry object will be used to handle the cluster client retries!
178 :param max_connections:
179 | Maximum number of connections per node. If there are no free connections & the
180 maximum number of connections are already created, a
181 :class:`~.MaxConnectionsError` is raised.
182 :param address_remap:
183 | An optional callable which, when provided with an internal network
184 address of a node, e.g. a `(host, port)` tuple, will return the address
185 where the node is reachable. This can be used to map the addresses at
186 which the nodes _think_ they are, to addresses at which a client may
187 reach them, such as when they sit behind a proxy.
188
189 | Rest of the arguments will be passed to the
190 :class:`~redis.asyncio.connection.Connection` instances when created
191
192 :raises RedisClusterException:
193 if any arguments are invalid or unknown. Eg:
194
195 - `db` != 0 or None
196 - `path` argument for unix socket connection
197 - none of the `host`/`port` & `startup_nodes` were provided
198
199 """
200
201 @classmethod
202 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
203 """
204 Return a Redis client object configured from the given URL.
205
206 For example::
207
208 redis://[[username]:[password]]@localhost:6379/0
209 rediss://[[username]:[password]]@localhost:6379/0
210
211 Three URL schemes are supported:
212
213 - `redis://` creates a TCP socket connection. See more at:
214 <https://www.iana.org/assignments/uri-schemes/prov/redis>
215 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
216 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
217
218 The username, password, hostname, path and all querystring values are passed
219 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
220 with their corresponding characters.
221
222 All querystring options are cast to their appropriate Python types. Boolean
223 arguments can be specified with string values "True"/"False" or "Yes"/"No".
224 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
225 parsed, the querystring arguments and keyword arguments are passed to
226 :class:`~redis.asyncio.connection.Connection` when created.
227 In the case of conflicting arguments, querystring arguments are used.
228 """
229 kwargs.update(parse_url(url))
230 if kwargs.pop("connection_class", None) is SSLConnection:
231 kwargs["ssl"] = True
232 return cls(**kwargs)
233
234 __slots__ = (
235 "_initialize",
236 "_lock",
237 "retry",
238 "command_flags",
239 "commands_parser",
240 "connection_kwargs",
241 "encoder",
242 "node_flags",
243 "nodes_manager",
244 "read_from_replicas",
245 "reinitialize_counter",
246 "reinitialize_steps",
247 "response_callbacks",
248 "result_callbacks",
249 )
250
251 @deprecated_args(
252 args_to_warn=["read_from_replicas"],
253 reason="Please configure the 'load_balancing_strategy' instead",
254 version="5.3.0",
255 )
256 @deprecated_args(
257 args_to_warn=[
258 "cluster_error_retry_attempts",
259 ],
260 reason="Please configure the 'retry' object instead",
261 version="6.0.0",
262 )
263 def __init__(
264 self,
265 host: Optional[str] = None,
266 port: Union[str, int] = 6379,
267 # Cluster related kwargs
268 startup_nodes: Optional[List["ClusterNode"]] = None,
269 require_full_coverage: bool = True,
270 read_from_replicas: bool = False,
271 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
272 dynamic_startup_nodes: bool = True,
273 reinitialize_steps: int = 5,
274 cluster_error_retry_attempts: int = 3,
275 max_connections: int = 2**31,
276 retry: Optional["Retry"] = None,
277 retry_on_error: Optional[List[Type[Exception]]] = None,
278 # Client related kwargs
279 db: Union[str, int] = 0,
280 path: Optional[str] = None,
281 credential_provider: Optional[CredentialProvider] = None,
282 username: Optional[str] = None,
283 password: Optional[str] = None,
284 client_name: Optional[str] = None,
285 lib_name: Optional[str] = "redis-py",
286 lib_version: Optional[str] = get_lib_version(),
287 # Encoding related kwargs
288 encoding: str = "utf-8",
289 encoding_errors: str = "strict",
290 decode_responses: bool = False,
291 # Connection related kwargs
292 health_check_interval: float = 0,
293 socket_connect_timeout: Optional[float] = None,
294 socket_keepalive: bool = False,
295 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
296 socket_timeout: Optional[float] = None,
297 # SSL related kwargs
298 ssl: bool = False,
299 ssl_ca_certs: Optional[str] = None,
300 ssl_ca_data: Optional[str] = None,
301 ssl_cert_reqs: Union[str, VerifyMode] = "required",
302 ssl_certfile: Optional[str] = None,
303 ssl_check_hostname: bool = True,
304 ssl_keyfile: Optional[str] = None,
305 ssl_min_version: Optional[TLSVersion] = None,
306 ssl_ciphers: Optional[str] = None,
307 protocol: Optional[int] = 2,
308 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
309 event_dispatcher: Optional[EventDispatcher] = None,
310 ) -> None:
311 if db:
312 raise RedisClusterException(
313 "Argument 'db' must be 0 or None in cluster mode"
314 )
315
316 if path:
317 raise RedisClusterException(
318 "Unix domain socket is not supported in cluster mode"
319 )
320
321 if (not host or not port) and not startup_nodes:
322 raise RedisClusterException(
323 "RedisCluster requires at least one node to discover the cluster.\n"
324 "Please provide one of the following or use RedisCluster.from_url:\n"
325 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
326 " - startup_nodes: RedisCluster(startup_nodes=["
327 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
328 )
329
330 kwargs: Dict[str, Any] = {
331 "max_connections": max_connections,
332 "connection_class": Connection,
333 # Client related kwargs
334 "credential_provider": credential_provider,
335 "username": username,
336 "password": password,
337 "client_name": client_name,
338 "lib_name": lib_name,
339 "lib_version": lib_version,
340 # Encoding related kwargs
341 "encoding": encoding,
342 "encoding_errors": encoding_errors,
343 "decode_responses": decode_responses,
344 # Connection related kwargs
345 "health_check_interval": health_check_interval,
346 "socket_connect_timeout": socket_connect_timeout,
347 "socket_keepalive": socket_keepalive,
348 "socket_keepalive_options": socket_keepalive_options,
349 "socket_timeout": socket_timeout,
350 "protocol": protocol,
351 }
352
353 if ssl:
354 # SSL related kwargs
355 kwargs.update(
356 {
357 "connection_class": SSLConnection,
358 "ssl_ca_certs": ssl_ca_certs,
359 "ssl_ca_data": ssl_ca_data,
360 "ssl_cert_reqs": ssl_cert_reqs,
361 "ssl_certfile": ssl_certfile,
362 "ssl_check_hostname": ssl_check_hostname,
363 "ssl_keyfile": ssl_keyfile,
364 "ssl_min_version": ssl_min_version,
365 "ssl_ciphers": ssl_ciphers,
366 }
367 )
368
369 if read_from_replicas or load_balancing_strategy:
370 # Call our on_connect function to configure READONLY mode
371 kwargs["redis_connect_func"] = self.on_connect
372
373 if retry:
374 self.retry = retry
375 else:
376 self.retry = Retry(
377 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
378 retries=cluster_error_retry_attempts,
379 )
380 if retry_on_error:
381 self.retry.update_supported_errors(retry_on_error)
382
383 kwargs["response_callbacks"] = _RedisCallbacks.copy()
384 if kwargs.get("protocol") in ["3", 3]:
385 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
386 else:
387 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
388 self.connection_kwargs = kwargs
389
390 if startup_nodes:
391 passed_nodes = []
392 for node in startup_nodes:
393 passed_nodes.append(
394 ClusterNode(node.host, node.port, **self.connection_kwargs)
395 )
396 startup_nodes = passed_nodes
397 else:
398 startup_nodes = []
399 if host and port:
400 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
401
402 if event_dispatcher is None:
403 self._event_dispatcher = EventDispatcher()
404 else:
405 self._event_dispatcher = event_dispatcher
406
407 self.nodes_manager = NodesManager(
408 startup_nodes,
409 require_full_coverage,
410 kwargs,
411 dynamic_startup_nodes=dynamic_startup_nodes,
412 address_remap=address_remap,
413 event_dispatcher=self._event_dispatcher,
414 )
415 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
416 self.read_from_replicas = read_from_replicas
417 self.load_balancing_strategy = load_balancing_strategy
418 self.reinitialize_steps = reinitialize_steps
419 self.reinitialize_counter = 0
420 self.commands_parser = AsyncCommandsParser()
421 self.node_flags = self.__class__.NODE_FLAGS.copy()
422 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
423 self.response_callbacks = kwargs["response_callbacks"]
424 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
425 self.result_callbacks["CLUSTER SLOTS"] = (
426 lambda cmd, res, **kwargs: parse_cluster_slots(
427 list(res.values())[0], **kwargs
428 )
429 )
430
431 self._initialize = True
432 self._lock: Optional[asyncio.Lock] = None
433
434 async def initialize(self) -> "RedisCluster":
435 """Get all nodes from startup nodes & creates connections if not initialized."""
436 if self._initialize:
437 if not self._lock:
438 self._lock = asyncio.Lock()
439 async with self._lock:
440 if self._initialize:
441 try:
442 await self.nodes_manager.initialize()
443 await self.commands_parser.initialize(
444 self.nodes_manager.default_node
445 )
446 self._initialize = False
447 except BaseException:
448 await self.nodes_manager.aclose()
449 await self.nodes_manager.aclose("startup_nodes")
450 raise
451 return self
452
453 async def aclose(self) -> None:
454 """Close all connections & client if initialized."""
455 if not self._initialize:
456 if not self._lock:
457 self._lock = asyncio.Lock()
458 async with self._lock:
459 if not self._initialize:
460 self._initialize = True
461 await self.nodes_manager.aclose()
462 await self.nodes_manager.aclose("startup_nodes")
463
464 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
465 async def close(self) -> None:
466 """alias for aclose() for backwards compatibility"""
467 await self.aclose()
468
469 async def __aenter__(self) -> "RedisCluster":
470 return await self.initialize()
471
472 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
473 await self.aclose()
474
475 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
476 return self.initialize().__await__()
477
478 _DEL_MESSAGE = "Unclosed RedisCluster client"
479
480 def __del__(
481 self,
482 _warn: Any = warnings.warn,
483 _grl: Any = asyncio.get_running_loop,
484 ) -> None:
485 if hasattr(self, "_initialize") and not self._initialize:
486 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
487 try:
488 context = {"client": self, "message": self._DEL_MESSAGE}
489 _grl().call_exception_handler(context)
490 except RuntimeError:
491 pass
492
493 async def on_connect(self, connection: Connection) -> None:
494 await connection.on_connect()
495
496 # Sending READONLY command to server to configure connection as
497 # readonly. Since each cluster node may change its server type due
498 # to a failover, we should establish a READONLY connection
499 # regardless of the server type. If this is a primary connection,
500 # READONLY would not affect executing write commands.
501 await connection.send_command("READONLY")
502 if str_if_bytes(await connection.read_response()) != "OK":
503 raise ConnectionError("READONLY command failed")
504
505 def get_nodes(self) -> List["ClusterNode"]:
506 """Get all nodes of the cluster."""
507 return list(self.nodes_manager.nodes_cache.values())
508
509 def get_primaries(self) -> List["ClusterNode"]:
510 """Get the primary nodes of the cluster."""
511 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
512
513 def get_replicas(self) -> List["ClusterNode"]:
514 """Get the replica nodes of the cluster."""
515 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
516
517 def get_random_node(self) -> "ClusterNode":
518 """Get a random node of the cluster."""
519 return random.choice(list(self.nodes_manager.nodes_cache.values()))
520
521 def get_default_node(self) -> "ClusterNode":
522 """Get the default node of the client."""
523 return self.nodes_manager.default_node
524
525 def set_default_node(self, node: "ClusterNode") -> None:
526 """
527 Set the default node of the client.
528
529 :raises DataError: if None is passed or node does not exist in cluster.
530 """
531 if not node or not self.get_node(node_name=node.name):
532 raise DataError("The requested node does not exist in the cluster.")
533
534 self.nodes_manager.default_node = node
535
536 def get_node(
537 self,
538 host: Optional[str] = None,
539 port: Optional[int] = None,
540 node_name: Optional[str] = None,
541 ) -> Optional["ClusterNode"]:
542 """Get node by (host, port) or node_name."""
543 return self.nodes_manager.get_node(host, port, node_name)
544
545 def get_node_from_key(
546 self, key: str, replica: bool = False
547 ) -> Optional["ClusterNode"]:
548 """
549 Get the cluster node corresponding to the provided key.
550
551 :param key:
552 :param replica:
553 | Indicates if a replica should be returned
554 |
555 None will returned if no replica holds this key
556
557 :raises SlotNotCoveredError: if the key is not covered by any slot.
558 """
559 slot = self.keyslot(key)
560 slot_cache = self.nodes_manager.slots_cache.get(slot)
561 if not slot_cache:
562 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
563
564 if replica:
565 if len(self.nodes_manager.slots_cache[slot]) < 2:
566 return None
567 node_idx = 1
568 else:
569 node_idx = 0
570
571 return slot_cache[node_idx]
572
573 def keyslot(self, key: EncodableT) -> int:
574 """
575 Find the keyslot for a given key.
576
577 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
578 """
579 return key_slot(self.encoder.encode(key))
580
581 def get_encoder(self) -> Encoder:
582 """Get the encoder object of the client."""
583 return self.encoder
584
585 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
586 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
587 return self.connection_kwargs
588
589 def set_retry(self, retry: Retry) -> None:
590 self.retry = retry
591
592 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
593 """Set a custom response callback."""
594 self.response_callbacks[command] = callback
595
596 async def _determine_nodes(
597 self, command: str, *args: Any, node_flag: Optional[str] = None
598 ) -> List["ClusterNode"]:
599 # Determine which nodes should be executed the command on.
600 # Returns a list of target nodes.
601 if not node_flag:
602 # get the nodes group for this command if it was predefined
603 node_flag = self.command_flags.get(command)
604
605 if node_flag in self.node_flags:
606 if node_flag == self.__class__.DEFAULT_NODE:
607 # return the cluster's default node
608 return [self.nodes_manager.default_node]
609 if node_flag == self.__class__.PRIMARIES:
610 # return all primaries
611 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
612 if node_flag == self.__class__.REPLICAS:
613 # return all replicas
614 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
615 if node_flag == self.__class__.ALL_NODES:
616 # return all nodes
617 return list(self.nodes_manager.nodes_cache.values())
618 if node_flag == self.__class__.RANDOM:
619 # return a random node
620 return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
621
622 # get the node that holds the key's slot
623 return [
624 self.nodes_manager.get_node_from_slot(
625 await self._determine_slot(command, *args),
626 self.read_from_replicas and command in READ_COMMANDS,
627 self.load_balancing_strategy if command in READ_COMMANDS else None,
628 )
629 ]
630
631 async def _determine_slot(self, command: str, *args: Any) -> int:
632 if self.command_flags.get(command) == SLOT_ID:
633 # The command contains the slot ID
634 return int(args[0])
635
636 # Get the keys in the command
637
638 # EVAL and EVALSHA are common enough that it's wasteful to go to the
639 # redis server to parse the keys. Besides, there is a bug in redis<7.0
640 # where `self._get_command_keys()` fails anyway. So, we special case
641 # EVAL/EVALSHA.
642 # - issue: https://github.com/redis/redis/issues/9493
643 # - fix: https://github.com/redis/redis/pull/9733
644 if command.upper() in ("EVAL", "EVALSHA"):
645 # command syntax: EVAL "script body" num_keys ...
646 if len(args) < 2:
647 raise RedisClusterException(
648 f"Invalid args in command: {command, *args}"
649 )
650 keys = args[2 : 2 + int(args[1])]
651 # if there are 0 keys, that means the script can be run on any node
652 # so we can just return a random slot
653 if not keys:
654 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
655 else:
656 keys = await self.commands_parser.get_keys(command, *args)
657 if not keys:
658 # FCALL can call a function with 0 keys, that means the function
659 # can be run on any node so we can just return a random slot
660 if command.upper() in ("FCALL", "FCALL_RO"):
661 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
662 raise RedisClusterException(
663 "No way to dispatch this command to Redis Cluster. "
664 "Missing key.\nYou can execute the command by specifying "
665 f"target nodes.\nCommand: {args}"
666 )
667
668 # single key command
669 if len(keys) == 1:
670 return self.keyslot(keys[0])
671
672 # multi-key command; we need to make sure all keys are mapped to
673 # the same slot
674 slots = {self.keyslot(key) for key in keys}
675 if len(slots) != 1:
676 raise RedisClusterException(
677 f"{command} - all keys must map to the same key slot"
678 )
679
680 return slots.pop()
681
682 def _is_node_flag(self, target_nodes: Any) -> bool:
683 return isinstance(target_nodes, str) and target_nodes in self.node_flags
684
685 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
686 if isinstance(target_nodes, list):
687 nodes = target_nodes
688 elif isinstance(target_nodes, ClusterNode):
689 # Supports passing a single ClusterNode as a variable
690 nodes = [target_nodes]
691 elif isinstance(target_nodes, dict):
692 # Supports dictionaries of the format {node_name: node}.
693 # It enables to execute commands with multi nodes as follows:
694 # rc.cluster_save_config(rc.get_primaries())
695 nodes = list(target_nodes.values())
696 else:
697 raise TypeError(
698 "target_nodes type can be one of the following: "
699 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
700 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
701 f"The passed type is {type(target_nodes)}"
702 )
703 return nodes
704
705 async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
706 """
707 Execute a raw command on the appropriate cluster node or target_nodes.
708
709 It will retry the command as specified by the retries property of
710 the :attr:`retry` & then raise an exception.
711
712 :param args:
713 | Raw command args
714 :param kwargs:
715
716 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
717 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
718 - Rest of the kwargs are passed to the Redis connection
719
720 :raises RedisClusterException: if target_nodes is not provided & the command
721 can't be mapped to a slot
722 """
723 command = args[0]
724 target_nodes = []
725 target_nodes_specified = False
726 retry_attempts = self.retry.get_retries()
727
728 passed_targets = kwargs.pop("target_nodes", None)
729 if passed_targets and not self._is_node_flag(passed_targets):
730 target_nodes = self._parse_target_nodes(passed_targets)
731 target_nodes_specified = True
732 retry_attempts = 0
733
734 # Add one for the first execution
735 execute_attempts = 1 + retry_attempts
736 for _ in range(execute_attempts):
737 if self._initialize:
738 await self.initialize()
739 if (
740 len(target_nodes) == 1
741 and target_nodes[0] == self.get_default_node()
742 ):
743 # Replace the default cluster node
744 self.replace_default_node()
745 try:
746 if not target_nodes_specified:
747 # Determine the nodes to execute the command on
748 target_nodes = await self._determine_nodes(
749 *args, node_flag=passed_targets
750 )
751 if not target_nodes:
752 raise RedisClusterException(
753 f"No targets were found to execute {args} command on"
754 )
755
756 if len(target_nodes) == 1:
757 # Return the processed result
758 ret = await self._execute_command(target_nodes[0], *args, **kwargs)
759 if command in self.result_callbacks:
760 return self.result_callbacks[command](
761 command, {target_nodes[0].name: ret}, **kwargs
762 )
763 return ret
764 else:
765 keys = [node.name for node in target_nodes]
766 values = await asyncio.gather(
767 *(
768 asyncio.create_task(
769 self._execute_command(node, *args, **kwargs)
770 )
771 for node in target_nodes
772 )
773 )
774 if command in self.result_callbacks:
775 return self.result_callbacks[command](
776 command, dict(zip(keys, values)), **kwargs
777 )
778 return dict(zip(keys, values))
779 except Exception as e:
780 if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
781 # The nodes and slots cache were should be reinitialized.
782 # Try again with the new cluster setup.
783 retry_attempts -= 1
784 continue
785 else:
786 # raise the exception
787 raise e
788
789 async def _execute_command(
790 self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
791 ) -> Any:
792 asking = moved = False
793 redirect_addr = None
794 ttl = self.RedisClusterRequestTTL
795
796 while ttl > 0:
797 ttl -= 1
798 try:
799 if asking:
800 target_node = self.get_node(node_name=redirect_addr)
801 await target_node.execute_command("ASKING")
802 asking = False
803 elif moved:
804 # MOVED occurred and the slots cache was updated,
805 # refresh the target node
806 slot = await self._determine_slot(*args)
807 target_node = self.nodes_manager.get_node_from_slot(
808 slot,
809 self.read_from_replicas and args[0] in READ_COMMANDS,
810 self.load_balancing_strategy
811 if args[0] in READ_COMMANDS
812 else None,
813 )
814 moved = False
815
816 return await target_node.execute_command(*args, **kwargs)
817 except (BusyLoadingError, MaxConnectionsError):
818 raise
819 except (ConnectionError, TimeoutError):
820 # Connection retries are being handled in the node's
821 # Retry object.
822 # Remove the failed node from the startup nodes before we try
823 # to reinitialize the cluster
824 self.nodes_manager.startup_nodes.pop(target_node.name, None)
825 # Hard force of reinitialize of the node/slots setup
826 # and try again with the new setup
827 await self.aclose()
828 raise
829 except (ClusterDownError, SlotNotCoveredError):
830 # ClusterDownError can occur during a failover and to get
831 # self-healed, we will try to reinitialize the cluster layout
832 # and retry executing the command
833
834 # SlotNotCoveredError can occur when the cluster is not fully
835 # initialized or can be temporary issue.
836 # We will try to reinitialize the cluster topology
837 # and retry executing the command
838
839 await self.aclose()
840 await asyncio.sleep(0.25)
841 raise
842 except MovedError as e:
843 # First, we will try to patch the slots/nodes cache with the
844 # redirected node output and try again. If MovedError exceeds
845 # 'reinitialize_steps' number of times, we will force
846 # reinitializing the tables, and then try again.
847 # 'reinitialize_steps' counter will increase faster when
848 # the same client object is shared between multiple threads. To
849 # reduce the frequency you can set this variable in the
850 # RedisCluster constructor.
851 self.reinitialize_counter += 1
852 if (
853 self.reinitialize_steps
854 and self.reinitialize_counter % self.reinitialize_steps == 0
855 ):
856 await self.aclose()
857 # Reset the counter
858 self.reinitialize_counter = 0
859 else:
860 self.nodes_manager._moved_exception = e
861 moved = True
862 except AskError as e:
863 redirect_addr = get_node_name(host=e.host, port=e.port)
864 asking = True
865 except TryAgainError:
866 if ttl < self.RedisClusterRequestTTL / 2:
867 await asyncio.sleep(0.05)
868
869 raise ClusterError("TTL exhausted.")
870
871 def pipeline(
872 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
873 ) -> "ClusterPipeline":
874 """
875 Create & return a new :class:`~.ClusterPipeline` object.
876
877 Cluster implementation of pipeline does not support transaction or shard_hint.
878
879 :raises RedisClusterException: if transaction or shard_hint are truthy values
880 """
881 if shard_hint:
882 raise RedisClusterException("shard_hint is deprecated in cluster mode")
883
884 return ClusterPipeline(self, transaction)
885
886 def lock(
887 self,
888 name: KeyT,
889 timeout: Optional[float] = None,
890 sleep: float = 0.1,
891 blocking: bool = True,
892 blocking_timeout: Optional[float] = None,
893 lock_class: Optional[Type[Lock]] = None,
894 thread_local: bool = True,
895 raise_on_release_error: bool = True,
896 ) -> Lock:
897 """
898 Return a new Lock object using key ``name`` that mimics
899 the behavior of threading.Lock.
900
901 If specified, ``timeout`` indicates a maximum life for the lock.
902 By default, it will remain locked until release() is called.
903
904 ``sleep`` indicates the amount of time to sleep per loop iteration
905 when the lock is in blocking mode and another client is currently
906 holding the lock.
907
908 ``blocking`` indicates whether calling ``acquire`` should block until
909 the lock has been acquired or to fail immediately, causing ``acquire``
910 to return False and the lock not being acquired. Defaults to True.
911 Note this value can be overridden by passing a ``blocking``
912 argument to ``acquire``.
913
914 ``blocking_timeout`` indicates the maximum amount of time in seconds to
915 spend trying to acquire the lock. A value of ``None`` indicates
916 continue trying forever. ``blocking_timeout`` can be specified as a
917 float or integer, both representing the number of seconds to wait.
918
919 ``lock_class`` forces the specified lock implementation. Note that as
920 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
921 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
922 you have created your own custom lock class.
923
924 ``thread_local`` indicates whether the lock token is placed in
925 thread-local storage. By default, the token is placed in thread local
926 storage so that a thread only sees its token, not a token set by
927 another thread. Consider the following timeline:
928
929 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
930 thread-1 sets the token to "abc"
931 time: 1, thread-2 blocks trying to acquire `my-lock` using the
932 Lock instance.
933 time: 5, thread-1 has not yet completed. redis expires the lock
934 key.
935 time: 5, thread-2 acquired `my-lock` now that it's available.
936 thread-2 sets the token to "xyz"
937 time: 6, thread-1 finishes its work and calls release(). if the
938 token is *not* stored in thread local storage, then
939 thread-1 would see the token value as "xyz" and would be
940 able to successfully release the thread-2's lock.
941
942 ``raise_on_release_error`` indicates whether to raise an exception when
943 the lock is no longer owned when exiting the context manager. By default,
944 this is True, meaning an exception will be raised. If False, the warning
945 will be logged and the exception will be suppressed.
946
947 In some use cases it's necessary to disable thread local storage. For
948 example, if you have code where one thread acquires a lock and passes
949 that lock instance to a worker thread to release later. If thread
950 local storage isn't disabled in this case, the worker thread won't see
951 the token set by the thread that acquired the lock. Our assumption
952 is that these cases aren't common and as such default to using
953 thread local storage."""
954 if lock_class is None:
955 lock_class = Lock
956 return lock_class(
957 self,
958 name,
959 timeout=timeout,
960 sleep=sleep,
961 blocking=blocking,
962 blocking_timeout=blocking_timeout,
963 thread_local=thread_local,
964 raise_on_release_error=raise_on_release_error,
965 )
966
967 async def transaction(
968 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
969 ):
970 """
971 Convenience method for executing the callable `func` as a transaction
972 while watching all keys specified in `watches`. The 'func' callable
973 should expect a single argument which is a Pipeline object.
974 """
975 shard_hint = kwargs.pop("shard_hint", None)
976 value_from_callable = kwargs.pop("value_from_callable", False)
977 watch_delay = kwargs.pop("watch_delay", None)
978 async with self.pipeline(True, shard_hint) as pipe:
979 while True:
980 try:
981 if watches:
982 await pipe.watch(*watches)
983 func_value = await func(pipe)
984 exec_value = await pipe.execute()
985 return func_value if value_from_callable else exec_value
986 except WatchError:
987 if watch_delay is not None and watch_delay > 0:
988 time.sleep(watch_delay)
989 continue
990
991
992class ClusterNode:
993 """
994 Create a new ClusterNode.
995
996 Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
997 objects for the (host, port).
998 """
999
1000 __slots__ = (
1001 "_connections",
1002 "_free",
1003 "_lock",
1004 "_event_dispatcher",
1005 "connection_class",
1006 "connection_kwargs",
1007 "host",
1008 "max_connections",
1009 "name",
1010 "port",
1011 "response_callbacks",
1012 "server_type",
1013 )
1014
1015 def __init__(
1016 self,
1017 host: str,
1018 port: Union[str, int],
1019 server_type: Optional[str] = None,
1020 *,
1021 max_connections: int = 2**31,
1022 connection_class: Type[Connection] = Connection,
1023 **connection_kwargs: Any,
1024 ) -> None:
1025 if host == "localhost":
1026 host = socket.gethostbyname(host)
1027
1028 connection_kwargs["host"] = host
1029 connection_kwargs["port"] = port
1030 self.host = host
1031 self.port = port
1032 self.name = get_node_name(host, port)
1033 self.server_type = server_type
1034
1035 self.max_connections = max_connections
1036 self.connection_class = connection_class
1037 self.connection_kwargs = connection_kwargs
1038 self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
1039
1040 self._connections: List[Connection] = []
1041 self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
1042 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1043 if self._event_dispatcher is None:
1044 self._event_dispatcher = EventDispatcher()
1045
1046 def __repr__(self) -> str:
1047 return (
1048 f"[host={self.host}, port={self.port}, "
1049 f"name={self.name}, server_type={self.server_type}]"
1050 )
1051
1052 def __eq__(self, obj: Any) -> bool:
1053 return isinstance(obj, ClusterNode) and obj.name == self.name
1054
1055 _DEL_MESSAGE = "Unclosed ClusterNode object"
1056
1057 def __del__(
1058 self,
1059 _warn: Any = warnings.warn,
1060 _grl: Any = asyncio.get_running_loop,
1061 ) -> None:
1062 for connection in self._connections:
1063 if connection.is_connected:
1064 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
1065
1066 try:
1067 context = {"client": self, "message": self._DEL_MESSAGE}
1068 _grl().call_exception_handler(context)
1069 except RuntimeError:
1070 pass
1071 break
1072
1073 async def disconnect(self) -> None:
1074 ret = await asyncio.gather(
1075 *(
1076 asyncio.create_task(connection.disconnect())
1077 for connection in self._connections
1078 ),
1079 return_exceptions=True,
1080 )
1081 exc = next((res for res in ret if isinstance(res, Exception)), None)
1082 if exc:
1083 raise exc
1084
1085 def acquire_connection(self) -> Connection:
1086 try:
1087 return self._free.popleft()
1088 except IndexError:
1089 if len(self._connections) < self.max_connections:
1090 # We are configuring the connection pool not to retry
1091 # connections on lower level clients to avoid retrying
1092 # connections to nodes that are not reachable
1093 # and to avoid blocking the connection pool.
1094 # The only error that will have some handling in the lower
1095 # level clients is ConnectionError which will trigger disconnection
1096 # of the socket.
1097 # The retries will be handled on cluster client level
1098 # where we will have proper handling of the cluster topology
1099 retry = Retry(
1100 backoff=NoBackoff(),
1101 retries=0,
1102 supported_errors=(ConnectionError,),
1103 )
1104 connection_kwargs = self.connection_kwargs.copy()
1105 connection_kwargs["retry"] = retry
1106 connection = self.connection_class(**connection_kwargs)
1107 self._connections.append(connection)
1108 return connection
1109
1110 raise MaxConnectionsError()
1111
1112 def release(self, connection: Connection) -> None:
1113 """
1114 Release connection back to free queue.
1115 """
1116 self._free.append(connection)
1117
1118 async def parse_response(
1119 self, connection: Connection, command: str, **kwargs: Any
1120 ) -> Any:
1121 try:
1122 if NEVER_DECODE in kwargs:
1123 response = await connection.read_response(disable_decoding=True)
1124 kwargs.pop(NEVER_DECODE)
1125 else:
1126 response = await connection.read_response()
1127 except ResponseError:
1128 if EMPTY_RESPONSE in kwargs:
1129 return kwargs[EMPTY_RESPONSE]
1130 raise
1131
1132 if EMPTY_RESPONSE in kwargs:
1133 kwargs.pop(EMPTY_RESPONSE)
1134
1135 # Remove keys entry, it needs only for cache.
1136 kwargs.pop("keys", None)
1137
1138 # Return response
1139 if command in self.response_callbacks:
1140 return self.response_callbacks[command](response, **kwargs)
1141
1142 return response
1143
1144 async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
1145 # Acquire connection
1146 connection = self.acquire_connection()
1147
1148 # Execute command
1149 await connection.send_packed_command(connection.pack_command(*args), False)
1150
1151 # Read response
1152 try:
1153 return await self.parse_response(connection, args[0], **kwargs)
1154 finally:
1155 # Release connection
1156 self._free.append(connection)
1157
1158 async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1159 # Acquire connection
1160 connection = self.acquire_connection()
1161
1162 # Execute command
1163 await connection.send_packed_command(
1164 connection.pack_commands(cmd.args for cmd in commands), False
1165 )
1166
1167 # Read responses
1168 ret = False
1169 for cmd in commands:
1170 try:
1171 cmd.result = await self.parse_response(
1172 connection, cmd.args[0], **cmd.kwargs
1173 )
1174 except Exception as e:
1175 cmd.result = e
1176 ret = True
1177
1178 # Release connection
1179 self._free.append(connection)
1180
1181 return ret
1182
1183 async def re_auth_callback(self, token: TokenInterface):
1184 tmp_queue = collections.deque()
1185 while self._free:
1186 conn = self._free.popleft()
1187 await conn.retry.call_with_retry(
1188 lambda: conn.send_command(
1189 "AUTH", token.try_get("oid"), token.get_value()
1190 ),
1191 lambda error: self._mock(error),
1192 )
1193 await conn.retry.call_with_retry(
1194 lambda: conn.read_response(), lambda error: self._mock(error)
1195 )
1196 tmp_queue.append(conn)
1197
1198 while tmp_queue:
1199 conn = tmp_queue.popleft()
1200 self._free.append(conn)
1201
1202 async def _mock(self, error: RedisError):
1203 """
1204 Dummy functions, needs to be passed as error callback to retry object.
1205 :param error:
1206 :return:
1207 """
1208 pass
1209
1210
1211class NodesManager:
1212 __slots__ = (
1213 "_dynamic_startup_nodes",
1214 "_moved_exception",
1215 "_event_dispatcher",
1216 "connection_kwargs",
1217 "default_node",
1218 "nodes_cache",
1219 "read_load_balancer",
1220 "require_full_coverage",
1221 "slots_cache",
1222 "startup_nodes",
1223 "address_remap",
1224 )
1225
1226 def __init__(
1227 self,
1228 startup_nodes: List["ClusterNode"],
1229 require_full_coverage: bool,
1230 connection_kwargs: Dict[str, Any],
1231 dynamic_startup_nodes: bool = True,
1232 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1233 event_dispatcher: Optional[EventDispatcher] = None,
1234 ) -> None:
1235 self.startup_nodes = {node.name: node for node in startup_nodes}
1236 self.require_full_coverage = require_full_coverage
1237 self.connection_kwargs = connection_kwargs
1238 self.address_remap = address_remap
1239
1240 self.default_node: "ClusterNode" = None
1241 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1242 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1243 self.read_load_balancer = LoadBalancer()
1244
1245 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1246 self._moved_exception: MovedError = None
1247 if event_dispatcher is None:
1248 self._event_dispatcher = EventDispatcher()
1249 else:
1250 self._event_dispatcher = event_dispatcher
1251
1252 def get_node(
1253 self,
1254 host: Optional[str] = None,
1255 port: Optional[int] = None,
1256 node_name: Optional[str] = None,
1257 ) -> Optional["ClusterNode"]:
1258 if host and port:
1259 # the user passed host and port
1260 if host == "localhost":
1261 host = socket.gethostbyname(host)
1262 return self.nodes_cache.get(get_node_name(host=host, port=port))
1263 elif node_name:
1264 return self.nodes_cache.get(node_name)
1265 else:
1266 raise DataError(
1267 "get_node requires one of the following: 1. node name 2. host and port"
1268 )
1269
1270 def set_nodes(
1271 self,
1272 old: Dict[str, "ClusterNode"],
1273 new: Dict[str, "ClusterNode"],
1274 remove_old: bool = False,
1275 ) -> None:
1276 if remove_old:
1277 for name in list(old.keys()):
1278 if name not in new:
1279 task = asyncio.create_task(old.pop(name).disconnect()) # noqa
1280
1281 for name, node in new.items():
1282 if name in old:
1283 if old[name] is node:
1284 continue
1285 task = asyncio.create_task(old[name].disconnect()) # noqa
1286 old[name] = node
1287
1288 def update_moved_exception(self, exception):
1289 self._moved_exception = exception
1290
1291 def _update_moved_slots(self) -> None:
1292 e = self._moved_exception
1293 redirected_node = self.get_node(host=e.host, port=e.port)
1294 if redirected_node:
1295 # The node already exists
1296 if redirected_node.server_type != PRIMARY:
1297 # Update the node's server type
1298 redirected_node.server_type = PRIMARY
1299 else:
1300 # This is a new node, we will add it to the nodes cache
1301 redirected_node = ClusterNode(
1302 e.host, e.port, PRIMARY, **self.connection_kwargs
1303 )
1304 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
1305 if redirected_node in self.slots_cache[e.slot_id]:
1306 # The MOVED error resulted from a failover, and the new slot owner
1307 # had previously been a replica.
1308 old_primary = self.slots_cache[e.slot_id][0]
1309 # Update the old primary to be a replica and add it to the end of
1310 # the slot's node list
1311 old_primary.server_type = REPLICA
1312 self.slots_cache[e.slot_id].append(old_primary)
1313 # Remove the old replica, which is now a primary, from the slot's
1314 # node list
1315 self.slots_cache[e.slot_id].remove(redirected_node)
1316 # Override the old primary with the new one
1317 self.slots_cache[e.slot_id][0] = redirected_node
1318 if self.default_node == old_primary:
1319 # Update the default node with the new primary
1320 self.default_node = redirected_node
1321 else:
1322 # The new slot owner is a new server, or a server from a different
1323 # shard. We need to remove all current nodes from the slot's list
1324 # (including replications) and add just the new node.
1325 self.slots_cache[e.slot_id] = [redirected_node]
1326 # Reset moved_exception
1327 self._moved_exception = None
1328
1329 def get_node_from_slot(
1330 self,
1331 slot: int,
1332 read_from_replicas: bool = False,
1333 load_balancing_strategy=None,
1334 ) -> "ClusterNode":
1335 if self._moved_exception:
1336 self._update_moved_slots()
1337
1338 if read_from_replicas is True and load_balancing_strategy is None:
1339 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1340
1341 try:
1342 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1343 # get the server index using the strategy defined in load_balancing_strategy
1344 primary_name = self.slots_cache[slot][0].name
1345 node_idx = self.read_load_balancer.get_server_index(
1346 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1347 )
1348 return self.slots_cache[slot][node_idx]
1349 return self.slots_cache[slot][0]
1350 except (IndexError, TypeError):
1351 raise SlotNotCoveredError(
1352 f'Slot "{slot}" not covered by the cluster. '
1353 f'"require_full_coverage={self.require_full_coverage}"'
1354 )
1355
1356 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1357 return [
1358 node
1359 for node in self.nodes_cache.values()
1360 if node.server_type == server_type
1361 ]
1362
1363 async def initialize(self) -> None:
1364 self.read_load_balancer.reset()
1365 tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
1366 tmp_slots: Dict[int, List["ClusterNode"]] = {}
1367 disagreements = []
1368 startup_nodes_reachable = False
1369 fully_covered = False
1370 exception = None
1371 # Convert to tuple to prevent RuntimeError if self.startup_nodes
1372 # is modified during iteration
1373 for startup_node in tuple(self.startup_nodes.values()):
1374 try:
1375 # Make sure cluster mode is enabled on this node
1376 try:
1377 self._event_dispatcher.dispatch(
1378 AfterAsyncClusterInstantiationEvent(
1379 self.nodes_cache,
1380 self.connection_kwargs.get("credential_provider", None),
1381 )
1382 )
1383 cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
1384 except ResponseError:
1385 raise RedisClusterException(
1386 "Cluster mode is not enabled on this node"
1387 )
1388 startup_nodes_reachable = True
1389 except Exception as e:
1390 # Try the next startup node.
1391 # The exception is saved and raised only if we have no more nodes.
1392 exception = e
1393 continue
1394
1395 # CLUSTER SLOTS command results in the following output:
1396 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
1397 # where each node contains the following list: [IP, port, node_id]
1398 # Therefore, cluster_slots[0][2][0] will be the IP address of the
1399 # primary node of the first slot section.
1400 # If there's only one server in the cluster, its ``host`` is ''
1401 # Fix it to the host in startup_nodes
1402 if (
1403 len(cluster_slots) == 1
1404 and not cluster_slots[0][2][0]
1405 and len(self.startup_nodes) == 1
1406 ):
1407 cluster_slots[0][2][0] = startup_node.host
1408
1409 for slot in cluster_slots:
1410 for i in range(2, len(slot)):
1411 slot[i] = [str_if_bytes(val) for val in slot[i]]
1412 primary_node = slot[2]
1413 host = primary_node[0]
1414 if host == "":
1415 host = startup_node.host
1416 port = int(primary_node[1])
1417 host, port = self.remap_host_port(host, port)
1418
1419 nodes_for_slot = []
1420
1421 target_node = tmp_nodes_cache.get(get_node_name(host, port))
1422 if not target_node:
1423 target_node = ClusterNode(
1424 host, port, PRIMARY, **self.connection_kwargs
1425 )
1426 # add this node to the nodes cache
1427 tmp_nodes_cache[target_node.name] = target_node
1428 nodes_for_slot.append(target_node)
1429
1430 replica_nodes = slot[3:]
1431 for replica_node in replica_nodes:
1432 host = replica_node[0]
1433 port = replica_node[1]
1434 host, port = self.remap_host_port(host, port)
1435
1436 target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
1437 if not target_replica_node:
1438 target_replica_node = ClusterNode(
1439 host, port, REPLICA, **self.connection_kwargs
1440 )
1441 # add this node to the nodes cache
1442 tmp_nodes_cache[target_replica_node.name] = target_replica_node
1443 nodes_for_slot.append(target_replica_node)
1444
1445 for i in range(int(slot[0]), int(slot[1]) + 1):
1446 if i not in tmp_slots:
1447 tmp_slots[i] = nodes_for_slot
1448 else:
1449 # Validate that 2 nodes want to use the same slot cache
1450 # setup
1451 tmp_slot = tmp_slots[i][0]
1452 if tmp_slot.name != target_node.name:
1453 disagreements.append(
1454 f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
1455 )
1456
1457 if len(disagreements) > 5:
1458 raise RedisClusterException(
1459 f"startup_nodes could not agree on a valid "
1460 f"slots cache: {', '.join(disagreements)}"
1461 )
1462
1463 # Validate if all slots are covered or if we should try next startup node
1464 fully_covered = True
1465 for i in range(REDIS_CLUSTER_HASH_SLOTS):
1466 if i not in tmp_slots:
1467 fully_covered = False
1468 break
1469 if fully_covered:
1470 break
1471
1472 if not startup_nodes_reachable:
1473 raise RedisClusterException(
1474 f"Redis Cluster cannot be connected. Please provide at least "
1475 f"one reachable node: {str(exception)}"
1476 ) from exception
1477
1478 # Check if the slots are not fully covered
1479 if not fully_covered and self.require_full_coverage:
1480 # Despite the requirement that the slots be covered, there
1481 # isn't a full coverage
1482 raise RedisClusterException(
1483 f"All slots are not covered after query all startup_nodes. "
1484 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
1485 f"covered..."
1486 )
1487
1488 # Set the tmp variables to the real variables
1489 self.slots_cache = tmp_slots
1490 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1491
1492 if self._dynamic_startup_nodes:
1493 # Populate the startup nodes with all discovered nodes
1494 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1495
1496 # Set the default node
1497 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
1498 # If initialize was called after a MovedError, clear it
1499 self._moved_exception = None
1500
1501 async def aclose(self, attr: str = "nodes_cache") -> None:
1502 self.default_node = None
1503 await asyncio.gather(
1504 *(
1505 asyncio.create_task(node.disconnect())
1506 for node in getattr(self, attr).values()
1507 )
1508 )
1509
1510 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1511 """
1512 Remap the host and port returned from the cluster to a different
1513 internal value. Useful if the client is not connecting directly
1514 to the cluster.
1515 """
1516 if self.address_remap:
1517 return self.address_remap((host, port))
1518 return host, port
1519
1520
1521class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
1522 """
1523 Create a new ClusterPipeline object.
1524
1525 Usage::
1526
1527 result = await (
1528 rc.pipeline()
1529 .set("A", 1)
1530 .get("A")
1531 .hset("K", "F", "V")
1532 .hgetall("K")
1533 .mset_nonatomic({"A": 2, "B": 3})
1534 .get("A")
1535 .get("B")
1536 .delete("A", "B", "K")
1537 .execute()
1538 )
1539 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
1540
1541 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
1542 are split across multiple nodes, you'll get multiple results for them in the array.
1543
1544 Retryable errors:
1545 - :class:`~.ClusterDownError`
1546 - :class:`~.ConnectionError`
1547 - :class:`~.TimeoutError`
1548
1549 Redirection errors:
1550 - :class:`~.TryAgainError`
1551 - :class:`~.MovedError`
1552 - :class:`~.AskError`
1553
1554 :param client:
1555 | Existing :class:`~.RedisCluster` client
1556 """
1557
1558 __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1559
1560 def __init__(
1561 self, client: RedisCluster, transaction: Optional[bool] = None
1562 ) -> None:
1563 self.cluster_client = client
1564 self._transaction = transaction
1565 self._execution_strategy: ExecutionStrategy = (
1566 PipelineStrategy(self)
1567 if not self._transaction
1568 else TransactionStrategy(self)
1569 )
1570
1571 async def initialize(self) -> "ClusterPipeline":
1572 await self._execution_strategy.initialize()
1573 return self
1574
1575 async def __aenter__(self) -> "ClusterPipeline":
1576 return await self.initialize()
1577
1578 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1579 await self.reset()
1580
1581 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1582 return self.initialize().__await__()
1583
1584 def __bool__(self) -> bool:
1585 "Pipeline instances should always evaluate to True on Python 3+"
1586 return True
1587
1588 def __len__(self) -> int:
1589 return len(self._execution_strategy)
1590
1591 def execute_command(
1592 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1593 ) -> "ClusterPipeline":
1594 """
1595 Append a raw command to the pipeline.
1596
1597 :param args:
1598 | Raw command args
1599 :param kwargs:
1600
1601 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
1602 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1603 - Rest of the kwargs are passed to the Redis connection
1604 """
1605 return self._execution_strategy.execute_command(*args, **kwargs)
1606
1607 async def execute(
1608 self, raise_on_error: bool = True, allow_redirections: bool = True
1609 ) -> List[Any]:
1610 """
1611 Execute the pipeline.
1612
1613 It will retry the commands as specified by retries specified in :attr:`retry`
1614 & then raise an exception.
1615
1616 :param raise_on_error:
1617 | Raise the first error if there are any errors
1618 :param allow_redirections:
1619 | Whether to retry each failed command individually in case of redirection
1620 errors
1621
1622 :raises RedisClusterException: if target_nodes is not provided & the command
1623 can't be mapped to a slot
1624 """
1625 try:
1626 return await self._execution_strategy.execute(
1627 raise_on_error, allow_redirections
1628 )
1629 finally:
1630 await self.reset()
1631
1632 def _split_command_across_slots(
1633 self, command: str, *keys: KeyT
1634 ) -> "ClusterPipeline":
1635 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1636 self.execute_command(command, *slot_keys)
1637
1638 return self
1639
1640 async def reset(self):
1641 """
1642 Reset back to empty pipeline.
1643 """
1644 await self._execution_strategy.reset()
1645
1646 def multi(self):
1647 """
1648 Start a transactional block of the pipeline after WATCH commands
1649 are issued. End the transactional block with `execute`.
1650 """
1651 self._execution_strategy.multi()
1652
1653 async def discard(self):
1654 """ """
1655 await self._execution_strategy.discard()
1656
1657 async def watch(self, *names):
1658 """Watches the values at keys ``names``"""
1659 await self._execution_strategy.watch(*names)
1660
1661 async def unwatch(self):
1662 """Unwatches all previously specified keys"""
1663 await self._execution_strategy.unwatch()
1664
1665 async def unlink(self, *names):
1666 await self._execution_strategy.unlink(*names)
1667
1668 def mset_nonatomic(
1669 self, mapping: Mapping[AnyKeyT, EncodableT]
1670 ) -> "ClusterPipeline":
1671 return self._execution_strategy.mset_nonatomic(mapping)
1672
1673
1674for command in PIPELINE_BLOCKED_COMMANDS:
1675 command = command.replace(" ", "_").lower()
1676 if command == "mset_nonatomic":
1677 continue
1678
1679 setattr(ClusterPipeline, command, block_pipeline_command(command))
1680
1681
1682class PipelineCommand:
1683 def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
1684 self.args = args
1685 self.kwargs = kwargs
1686 self.position = position
1687 self.result: Union[Any, Exception] = None
1688
1689 def __repr__(self) -> str:
1690 return f"[{self.position}] {self.args} ({self.kwargs})"
1691
1692
1693class ExecutionStrategy(ABC):
1694 @abstractmethod
1695 async def initialize(self) -> "ClusterPipeline":
1696 """
1697 Initialize the execution strategy.
1698
1699 See ClusterPipeline.initialize()
1700 """
1701 pass
1702
1703 @abstractmethod
1704 def execute_command(
1705 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1706 ) -> "ClusterPipeline":
1707 """
1708 Append a raw command to the pipeline.
1709
1710 See ClusterPipeline.execute_command()
1711 """
1712 pass
1713
1714 @abstractmethod
1715 async def execute(
1716 self, raise_on_error: bool = True, allow_redirections: bool = True
1717 ) -> List[Any]:
1718 """
1719 Execute the pipeline.
1720
1721 It will retry the commands as specified by retries specified in :attr:`retry`
1722 & then raise an exception.
1723
1724 See ClusterPipeline.execute()
1725 """
1726 pass
1727
1728 @abstractmethod
1729 def mset_nonatomic(
1730 self, mapping: Mapping[AnyKeyT, EncodableT]
1731 ) -> "ClusterPipeline":
1732 """
1733 Executes multiple MSET commands according to the provided slot/pairs mapping.
1734
1735 See ClusterPipeline.mset_nonatomic()
1736 """
1737 pass
1738
1739 @abstractmethod
1740 async def reset(self):
1741 """
1742 Resets current execution strategy.
1743
1744 See: ClusterPipeline.reset()
1745 """
1746 pass
1747
1748 @abstractmethod
1749 def multi(self):
1750 """
1751 Starts transactional context.
1752
1753 See: ClusterPipeline.multi()
1754 """
1755 pass
1756
1757 @abstractmethod
1758 async def watch(self, *names):
1759 """
1760 Watch given keys.
1761
1762 See: ClusterPipeline.watch()
1763 """
1764 pass
1765
1766 @abstractmethod
1767 async def unwatch(self):
1768 """
1769 Unwatches all previously specified keys
1770
1771 See: ClusterPipeline.unwatch()
1772 """
1773 pass
1774
1775 @abstractmethod
1776 async def discard(self):
1777 pass
1778
1779 @abstractmethod
1780 async def unlink(self, *names):
1781 """
1782 "Unlink a key specified by ``names``"
1783
1784 See: ClusterPipeline.unlink()
1785 """
1786 pass
1787
1788 @abstractmethod
1789 def __len__(self) -> int:
1790 pass
1791
1792
1793class AbstractStrategy(ExecutionStrategy):
1794 def __init__(self, pipe: ClusterPipeline) -> None:
1795 self._pipe: ClusterPipeline = pipe
1796 self._command_queue: List["PipelineCommand"] = []
1797
1798 async def initialize(self) -> "ClusterPipeline":
1799 if self._pipe.cluster_client._initialize:
1800 await self._pipe.cluster_client.initialize()
1801 self._command_queue = []
1802 return self._pipe
1803
1804 def execute_command(
1805 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1806 ) -> "ClusterPipeline":
1807 self._command_queue.append(
1808 PipelineCommand(len(self._command_queue), *args, **kwargs)
1809 )
1810 return self._pipe
1811
1812 def _annotate_exception(self, exception, number, command):
1813 """
1814 Provides extra context to the exception prior to it being handled
1815 """
1816 cmd = " ".join(map(safe_str, command))
1817 msg = (
1818 f"Command # {number} ({truncate_text(cmd)}) of pipeline "
1819 f"caused error: {exception.args[0]}"
1820 )
1821 exception.args = (msg,) + exception.args[1:]
1822
1823 @abstractmethod
1824 def mset_nonatomic(
1825 self, mapping: Mapping[AnyKeyT, EncodableT]
1826 ) -> "ClusterPipeline":
1827 pass
1828
1829 @abstractmethod
1830 async def execute(
1831 self, raise_on_error: bool = True, allow_redirections: bool = True
1832 ) -> List[Any]:
1833 pass
1834
1835 @abstractmethod
1836 async def reset(self):
1837 pass
1838
1839 @abstractmethod
1840 def multi(self):
1841 pass
1842
1843 @abstractmethod
1844 async def watch(self, *names):
1845 pass
1846
1847 @abstractmethod
1848 async def unwatch(self):
1849 pass
1850
1851 @abstractmethod
1852 async def discard(self):
1853 pass
1854
1855 @abstractmethod
1856 async def unlink(self, *names):
1857 pass
1858
1859 def __len__(self) -> int:
1860 return len(self._command_queue)
1861
1862
1863class PipelineStrategy(AbstractStrategy):
1864 def __init__(self, pipe: ClusterPipeline) -> None:
1865 super().__init__(pipe)
1866
1867 def mset_nonatomic(
1868 self, mapping: Mapping[AnyKeyT, EncodableT]
1869 ) -> "ClusterPipeline":
1870 encoder = self._pipe.cluster_client.encoder
1871
1872 slots_pairs = {}
1873 for pair in mapping.items():
1874 slot = key_slot(encoder.encode(pair[0]))
1875 slots_pairs.setdefault(slot, []).extend(pair)
1876
1877 for pairs in slots_pairs.values():
1878 self.execute_command("MSET", *pairs)
1879
1880 return self._pipe
1881
1882 async def execute(
1883 self, raise_on_error: bool = True, allow_redirections: bool = True
1884 ) -> List[Any]:
1885 if not self._command_queue:
1886 return []
1887
1888 try:
1889 retry_attempts = self._pipe.cluster_client.retry.get_retries()
1890 while True:
1891 try:
1892 if self._pipe.cluster_client._initialize:
1893 await self._pipe.cluster_client.initialize()
1894 return await self._execute(
1895 self._pipe.cluster_client,
1896 self._command_queue,
1897 raise_on_error=raise_on_error,
1898 allow_redirections=allow_redirections,
1899 )
1900
1901 except RedisCluster.ERRORS_ALLOW_RETRY as e:
1902 if retry_attempts > 0:
1903 # Try again with the new cluster setup. All other errors
1904 # should be raised.
1905 retry_attempts -= 1
1906 await self._pipe.cluster_client.aclose()
1907 await asyncio.sleep(0.25)
1908 else:
1909 # All other errors should be raised.
1910 raise e
1911 finally:
1912 await self.reset()
1913
1914 async def _execute(
1915 self,
1916 client: "RedisCluster",
1917 stack: List["PipelineCommand"],
1918 raise_on_error: bool = True,
1919 allow_redirections: bool = True,
1920 ) -> List[Any]:
1921 todo = [
1922 cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
1923 ]
1924
1925 nodes = {}
1926 for cmd in todo:
1927 passed_targets = cmd.kwargs.pop("target_nodes", None)
1928 if passed_targets and not client._is_node_flag(passed_targets):
1929 target_nodes = client._parse_target_nodes(passed_targets)
1930 else:
1931 target_nodes = await client._determine_nodes(
1932 *cmd.args, node_flag=passed_targets
1933 )
1934 if not target_nodes:
1935 raise RedisClusterException(
1936 f"No targets were found to execute {cmd.args} command on"
1937 )
1938 if len(target_nodes) > 1:
1939 raise RedisClusterException(f"Too many targets for command {cmd.args}")
1940 node = target_nodes[0]
1941 if node.name not in nodes:
1942 nodes[node.name] = (node, [])
1943 nodes[node.name][1].append(cmd)
1944
1945 errors = await asyncio.gather(
1946 *(
1947 asyncio.create_task(node[0].execute_pipeline(node[1]))
1948 for node in nodes.values()
1949 )
1950 )
1951
1952 if any(errors):
1953 if allow_redirections:
1954 # send each errored command individually
1955 for cmd in todo:
1956 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
1957 try:
1958 cmd.result = await client.execute_command(
1959 *cmd.args, **cmd.kwargs
1960 )
1961 except Exception as e:
1962 cmd.result = e
1963
1964 if raise_on_error:
1965 for cmd in todo:
1966 result = cmd.result
1967 if isinstance(result, Exception):
1968 command = " ".join(map(safe_str, cmd.args))
1969 msg = (
1970 f"Command # {cmd.position + 1} "
1971 f"({truncate_text(command)}) "
1972 f"of pipeline caused error: {result.args}"
1973 )
1974 result.args = (msg,) + result.args[1:]
1975 raise result
1976
1977 default_cluster_node = client.get_default_node()
1978
1979 # Check whether the default node was used. In some cases,
1980 # 'client.get_default_node()' may return None. The check below
1981 # prevents a potential AttributeError.
1982 if default_cluster_node is not None:
1983 default_node = nodes.get(default_cluster_node.name)
1984 if default_node is not None:
1985 # This pipeline execution used the default node, check if we need
1986 # to replace it.
1987 # Note: when the error is raised we'll reset the default node in the
1988 # caller function.
1989 for cmd in default_node[1]:
1990 # Check if it has a command that failed with a relevant
1991 # exception
1992 if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
1993 client.replace_default_node()
1994 break
1995
1996 return [cmd.result for cmd in stack]
1997
1998 async def reset(self):
1999 """
2000 Reset back to empty pipeline.
2001 """
2002 self._command_queue = []
2003
2004 def multi(self):
2005 raise RedisClusterException(
2006 "method multi() is not supported outside of transactional context"
2007 )
2008
2009 async def watch(self, *names):
2010 raise RedisClusterException(
2011 "method watch() is not supported outside of transactional context"
2012 )
2013
2014 async def unwatch(self):
2015 raise RedisClusterException(
2016 "method unwatch() is not supported outside of transactional context"
2017 )
2018
2019 async def discard(self):
2020 raise RedisClusterException(
2021 "method discard() is not supported outside of transactional context"
2022 )
2023
2024 async def unlink(self, *names):
2025 if len(names) != 1:
2026 raise RedisClusterException(
2027 "unlinking multiple keys is not implemented in pipeline command"
2028 )
2029
2030 return self.execute_command("UNLINK", names[0])
2031
2032
2033class TransactionStrategy(AbstractStrategy):
2034 NO_SLOTS_COMMANDS = {"UNWATCH"}
2035 IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
2036 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
2037 SLOT_REDIRECT_ERRORS = (AskError, MovedError)
2038 CONNECTION_ERRORS = (
2039 ConnectionError,
2040 OSError,
2041 ClusterDownError,
2042 SlotNotCoveredError,
2043 )
2044
2045 def __init__(self, pipe: ClusterPipeline) -> None:
2046 super().__init__(pipe)
2047 self._explicit_transaction = False
2048 self._watching = False
2049 self._pipeline_slots: Set[int] = set()
2050 self._transaction_node: Optional[ClusterNode] = None
2051 self._transaction_connection: Optional[Connection] = None
2052 self._executing = False
2053 self._retry = copy(self._pipe.cluster_client.retry)
2054 self._retry.update_supported_errors(
2055 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2056 )
2057
2058 def _get_client_and_connection_for_transaction(
2059 self,
2060 ) -> Tuple[ClusterNode, Connection]:
2061 """
2062 Find a connection for a pipeline transaction.
2063
2064 For running an atomic transaction, watch keys ensure that contents have not been
2065 altered as long as the watch commands for those keys were sent over the same
2066 connection. So once we start watching a key, we fetch a connection to the
2067 node that owns that slot and reuse it.
2068 """
2069 if not self._pipeline_slots:
2070 raise RedisClusterException(
2071 "At least a command with a key is needed to identify a node"
2072 )
2073
2074 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2075 list(self._pipeline_slots)[0], False
2076 )
2077 self._transaction_node = node
2078
2079 if not self._transaction_connection:
2080 connection: Connection = self._transaction_node.acquire_connection()
2081 self._transaction_connection = connection
2082
2083 return self._transaction_node, self._transaction_connection
2084
2085 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2086 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2087 response = None
2088 error = None
2089
2090 def runner():
2091 nonlocal response
2092 nonlocal error
2093 try:
2094 response = asyncio.run(self._execute_command(*args, **kwargs))
2095 except Exception as e:
2096 error = e
2097
2098 thread = threading.Thread(target=runner)
2099 thread.start()
2100 thread.join()
2101
2102 if error:
2103 raise error
2104
2105 return response
2106
2107 async def _execute_command(
2108 self, *args: Union[KeyT, EncodableT], **kwargs: Any
2109 ) -> Any:
2110 if self._pipe.cluster_client._initialize:
2111 await self._pipe.cluster_client.initialize()
2112
2113 slot_number: Optional[int] = None
2114 if args[0] not in self.NO_SLOTS_COMMANDS:
2115 slot_number = await self._pipe.cluster_client._determine_slot(*args)
2116
2117 if (
2118 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2119 ) and not self._explicit_transaction:
2120 if args[0] == "WATCH":
2121 self._validate_watch()
2122
2123 if slot_number is not None:
2124 if self._pipeline_slots and slot_number not in self._pipeline_slots:
2125 raise CrossSlotTransactionError(
2126 "Cannot watch or send commands on different slots"
2127 )
2128
2129 self._pipeline_slots.add(slot_number)
2130 elif args[0] not in self.NO_SLOTS_COMMANDS:
2131 raise RedisClusterException(
2132 f"Cannot identify slot number for command: {args[0]},"
2133 "it cannot be triggered in a transaction"
2134 )
2135
2136 return self._immediate_execute_command(*args, **kwargs)
2137 else:
2138 if slot_number is not None:
2139 self._pipeline_slots.add(slot_number)
2140
2141 return super().execute_command(*args, **kwargs)
2142
2143 def _validate_watch(self):
2144 if self._explicit_transaction:
2145 raise RedisError("Cannot issue a WATCH after a MULTI")
2146
2147 self._watching = True
2148
2149 async def _immediate_execute_command(self, *args, **options):
2150 return await self._retry.call_with_retry(
2151 lambda: self._get_connection_and_send_command(*args, **options),
2152 self._reinitialize_on_error,
2153 )
2154
2155 async def _get_connection_and_send_command(self, *args, **options):
2156 redis_node, connection = self._get_client_and_connection_for_transaction()
2157 return await self._send_command_parse_response(
2158 connection, redis_node, args[0], *args, **options
2159 )
2160
2161 async def _send_command_parse_response(
2162 self,
2163 connection: Connection,
2164 redis_node: ClusterNode,
2165 command_name,
2166 *args,
2167 **options,
2168 ):
2169 """
2170 Send a command and parse the response
2171 """
2172
2173 await connection.send_command(*args)
2174 output = await redis_node.parse_response(connection, command_name, **options)
2175
2176 if command_name in self.UNWATCH_COMMANDS:
2177 self._watching = False
2178 return output
2179
2180 async def _reinitialize_on_error(self, error):
2181 if self._watching:
2182 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2183 raise WatchError("Slot rebalancing occurred while watching keys")
2184
2185 if (
2186 type(error) in self.SLOT_REDIRECT_ERRORS
2187 or type(error) in self.CONNECTION_ERRORS
2188 ):
2189 if self._transaction_connection:
2190 self._transaction_connection = None
2191
2192 self._pipe.cluster_client.reinitialize_counter += 1
2193 if (
2194 self._pipe.cluster_client.reinitialize_steps
2195 and self._pipe.cluster_client.reinitialize_counter
2196 % self._pipe.cluster_client.reinitialize_steps
2197 == 0
2198 ):
2199 await self._pipe.cluster_client.nodes_manager.initialize()
2200 self.reinitialize_counter = 0
2201 else:
2202 self._pipe.cluster_client.nodes_manager.update_moved_exception(error)
2203
2204 self._executing = False
2205
2206 def _raise_first_error(self, responses, stack):
2207 """
2208 Raise the first exception on the stack
2209 """
2210 for r, cmd in zip(responses, stack):
2211 if isinstance(r, Exception):
2212 self._annotate_exception(r, cmd.position + 1, cmd.args)
2213 raise r
2214
2215 def mset_nonatomic(
2216 self, mapping: Mapping[AnyKeyT, EncodableT]
2217 ) -> "ClusterPipeline":
2218 raise NotImplementedError("Method is not supported in transactional context.")
2219
2220 async def execute(
2221 self, raise_on_error: bool = True, allow_redirections: bool = True
2222 ) -> List[Any]:
2223 stack = self._command_queue
2224 if not stack and (not self._watching or not self._pipeline_slots):
2225 return []
2226
2227 return await self._execute_transaction_with_retries(stack, raise_on_error)
2228
2229 async def _execute_transaction_with_retries(
2230 self, stack: List["PipelineCommand"], raise_on_error: bool
2231 ):
2232 return await self._retry.call_with_retry(
2233 lambda: self._execute_transaction(stack, raise_on_error),
2234 self._reinitialize_on_error,
2235 )
2236
2237 async def _execute_transaction(
2238 self, stack: List["PipelineCommand"], raise_on_error: bool
2239 ):
2240 if len(self._pipeline_slots) > 1:
2241 raise CrossSlotTransactionError(
2242 "All keys involved in a cluster transaction must map to the same slot"
2243 )
2244
2245 self._executing = True
2246
2247 redis_node, connection = self._get_client_and_connection_for_transaction()
2248
2249 stack = chain(
2250 [PipelineCommand(0, "MULTI")],
2251 stack,
2252 [PipelineCommand(0, "EXEC")],
2253 )
2254 commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
2255 packed_commands = connection.pack_commands(commands)
2256 await connection.send_packed_command(packed_commands)
2257 errors = []
2258
2259 # parse off the response for MULTI
2260 # NOTE: we need to handle ResponseErrors here and continue
2261 # so that we read all the additional command messages from
2262 # the socket
2263 try:
2264 await redis_node.parse_response(connection, "MULTI")
2265 except ResponseError as e:
2266 self._annotate_exception(e, 0, "MULTI")
2267 errors.append(e)
2268 except self.CONNECTION_ERRORS as cluster_error:
2269 self._annotate_exception(cluster_error, 0, "MULTI")
2270 raise
2271
2272 # and all the other commands
2273 for i, command in enumerate(self._command_queue):
2274 if EMPTY_RESPONSE in command.kwargs:
2275 errors.append((i, command.kwargs[EMPTY_RESPONSE]))
2276 else:
2277 try:
2278 _ = await redis_node.parse_response(connection, "_")
2279 except self.SLOT_REDIRECT_ERRORS as slot_error:
2280 self._annotate_exception(slot_error, i + 1, command.args)
2281 errors.append(slot_error)
2282 except self.CONNECTION_ERRORS as cluster_error:
2283 self._annotate_exception(cluster_error, i + 1, command.args)
2284 raise
2285 except ResponseError as e:
2286 self._annotate_exception(e, i + 1, command.args)
2287 errors.append(e)
2288
2289 response = None
2290 # parse the EXEC.
2291 try:
2292 response = await redis_node.parse_response(connection, "EXEC")
2293 except ExecAbortError:
2294 if errors:
2295 raise errors[0]
2296 raise
2297
2298 self._executing = False
2299
2300 # EXEC clears any watched keys
2301 self._watching = False
2302
2303 if response is None:
2304 raise WatchError("Watched variable changed.")
2305
2306 # put any parse errors into the response
2307 for i, e in errors:
2308 response.insert(i, e)
2309
2310 if len(response) != len(self._command_queue):
2311 raise InvalidPipelineStack(
2312 "Unexpected response length for cluster pipeline EXEC."
2313 " Command stack was {} but response had length {}".format(
2314 [c.args[0] for c in self._command_queue], len(response)
2315 )
2316 )
2317
2318 # find any errors in the response and raise if necessary
2319 if raise_on_error or len(errors) > 0:
2320 self._raise_first_error(
2321 response,
2322 self._command_queue,
2323 )
2324
2325 # We have to run response callbacks manually
2326 data = []
2327 for r, cmd in zip(response, self._command_queue):
2328 if not isinstance(r, Exception):
2329 command_name = cmd.args[0]
2330 if command_name in self._pipe.cluster_client.response_callbacks:
2331 r = self._pipe.cluster_client.response_callbacks[command_name](
2332 r, **cmd.kwargs
2333 )
2334 data.append(r)
2335 return data
2336
2337 async def reset(self):
2338 self._command_queue = []
2339
2340 # make sure to reset the connection state in the event that we were
2341 # watching something
2342 if self._transaction_connection:
2343 try:
2344 if self._watching:
2345 # call this manually since our unwatch or
2346 # immediate_execute_command methods can call reset()
2347 await self._transaction_connection.send_command("UNWATCH")
2348 await self._transaction_connection.read_response()
2349 # we can safely return the connection to the pool here since we're
2350 # sure we're no longer WATCHing anything
2351 self._transaction_node.release(self._transaction_connection)
2352 self._transaction_connection = None
2353 except self.CONNECTION_ERRORS:
2354 # disconnect will also remove any previous WATCHes
2355 if self._transaction_connection:
2356 await self._transaction_connection.disconnect()
2357
2358 # clean up the other instance attributes
2359 self._transaction_node = None
2360 self._watching = False
2361 self._explicit_transaction = False
2362 self._pipeline_slots = set()
2363 self._executing = False
2364
2365 def multi(self):
2366 if self._explicit_transaction:
2367 raise RedisError("Cannot issue nested calls to MULTI")
2368 if self._command_queue:
2369 raise RedisError(
2370 "Commands without an initial WATCH have already been issued"
2371 )
2372 self._explicit_transaction = True
2373
2374 async def watch(self, *names):
2375 if self._explicit_transaction:
2376 raise RedisError("Cannot issue a WATCH after a MULTI")
2377
2378 return await self.execute_command("WATCH", *names)
2379
2380 async def unwatch(self):
2381 if self._watching:
2382 return await self.execute_command("UNWATCH")
2383
2384 return True
2385
2386 async def discard(self):
2387 await self.reset()
2388
2389 async def unlink(self, *names):
2390 return self.execute_command("UNLINK", *names)