1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.helpers import (
30 _RedisCallbacks,
31 _RedisCallbacksRESP2,
32 _RedisCallbacksRESP3,
33)
34from redis.asyncio.client import ResponseCallbackT
35from redis.asyncio.connection import Connection, SSLConnection, parse_url
36from redis.asyncio.lock import Lock
37from redis.asyncio.retry import Retry
38from redis.auth.token import TokenInterface
39from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
40from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
41from redis.cluster import (
42 PIPELINE_BLOCKED_COMMANDS,
43 PRIMARY,
44 REPLICA,
45 SLOT_ID,
46 AbstractRedisCluster,
47 LoadBalancer,
48 LoadBalancingStrategy,
49 block_pipeline_command,
50 get_node_name,
51 parse_cluster_slots,
52)
53from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
54from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
55from redis.credentials import CredentialProvider
56from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
57from redis.exceptions import (
58 AskError,
59 BusyLoadingError,
60 ClusterDownError,
61 ClusterError,
62 ConnectionError,
63 CrossSlotTransactionError,
64 DataError,
65 ExecAbortError,
66 InvalidPipelineStack,
67 MaxConnectionsError,
68 MovedError,
69 RedisClusterException,
70 RedisError,
71 ResponseError,
72 SlotNotCoveredError,
73 TimeoutError,
74 TryAgainError,
75 WatchError,
76)
77from redis.typing import AnyKeyT, EncodableT, KeyT
78from redis.utils import (
79 SSL_AVAILABLE,
80 deprecated_args,
81 deprecated_function,
82 get_lib_version,
83 safe_str,
84 str_if_bytes,
85 truncate_text,
86)
87
88if SSL_AVAILABLE:
89 from ssl import TLSVersion, VerifyFlags, VerifyMode
90else:
91 TLSVersion = None
92 VerifyMode = None
93 VerifyFlags = None
94
95TargetNodesT = TypeVar(
96 "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
97)
98
99
100class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
101 """
102 Create a new RedisCluster client.
103
104 Pass one of parameters:
105
106 - `host` & `port`
107 - `startup_nodes`
108
109 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
110 | Use ``await`` :meth:`close` to disconnect connections & close client.
111
112 Many commands support the target_nodes kwarg. It can be one of the
113 :attr:`NODE_FLAGS`:
114
115 - :attr:`PRIMARIES`
116 - :attr:`REPLICAS`
117 - :attr:`ALL_NODES`
118 - :attr:`RANDOM`
119 - :attr:`DEFAULT_NODE`
120
121 Note: This client is not thread/process/fork safe.
122
123 :param host:
124 | Can be used to point to a startup node
125 :param port:
126 | Port used if **host** is provided
127 :param startup_nodes:
128 | :class:`~.ClusterNode` to used as a startup node
129 :param require_full_coverage:
130 | When set to ``False``: the client will not require a full coverage of
131 the slots. However, if not all slots are covered, and at least one node
132 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
133 a :class:`~.ClusterDownError` for some key-based commands.
134 | When set to ``True``: all slots must be covered to construct the cluster
135 client. If not all slots are covered, :class:`~.RedisClusterException` will be
136 thrown.
137 | See:
138 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
139 :param read_from_replicas:
140 | @deprecated - please use load_balancing_strategy instead
141 | Enable read from replicas in READONLY mode.
142 When set to true, read commands will be assigned between the primary and
143 its replications in a Round-Robin manner.
144 The data read from replicas is eventually consistent with the data in primary nodes.
145 :param load_balancing_strategy:
146 | Enable read from replicas in READONLY mode and defines the load balancing
147 strategy that will be used for cluster node selection.
148 The data read from replicas is eventually consistent with the data in primary nodes.
149 :param dynamic_startup_nodes:
150 | Set the RedisCluster's startup nodes to all the discovered nodes.
151 If true (default value), the cluster's discovered nodes will be used to
152 determine the cluster nodes-slots mapping in the next topology refresh.
153 It will remove the initial passed startup nodes if their endpoints aren't
154 listed in the CLUSTER SLOTS output.
155 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
156 specific IP addresses, it is best to set it to false.
157 :param reinitialize_steps:
158 | Specifies the number of MOVED errors that need to occur before reinitializing
159 the whole cluster topology. If a MOVED error occurs and the cluster does not
160 need to be reinitialized on this current error handling, only the MOVED slot
161 will be patched with the redirected node.
162 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
163 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
164 0.
165 :param cluster_error_retry_attempts:
166 | @deprecated - Please configure the 'retry' object instead
167 In case 'retry' object is set - this argument is ignored!
168
169 Number of times to retry before raising an error when :class:`~.TimeoutError`,
170 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
171 or :class:`~.ClusterDownError` are encountered
172 :param retry:
173 | A retry object that defines the retry strategy and the number of
174 retries for the cluster client.
175 In current implementation for the cluster client (starting form redis-py version 6.0.0)
176 the retry object is not yet fully utilized, instead it is used just to determine
177 the number of retries for the cluster client.
178 In the future releases the retry object will be used to handle the cluster client retries!
179 :param max_connections:
180 | Maximum number of connections per node. If there are no free connections & the
181 maximum number of connections are already created, a
182 :class:`~.MaxConnectionsError` is raised.
183 :param address_remap:
184 | An optional callable which, when provided with an internal network
185 address of a node, e.g. a `(host, port)` tuple, will return the address
186 where the node is reachable. This can be used to map the addresses at
187 which the nodes _think_ they are, to addresses at which a client may
188 reach them, such as when they sit behind a proxy.
189
190 | Rest of the arguments will be passed to the
191 :class:`~redis.asyncio.connection.Connection` instances when created
192
193 :raises RedisClusterException:
194 if any arguments are invalid or unknown. Eg:
195
196 - `db` != 0 or None
197 - `path` argument for unix socket connection
198 - none of the `host`/`port` & `startup_nodes` were provided
199
200 """
201
202 @classmethod
203 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
204 """
205 Return a Redis client object configured from the given URL.
206
207 For example::
208
209 redis://[[username]:[password]]@localhost:6379/0
210 rediss://[[username]:[password]]@localhost:6379/0
211
212 Three URL schemes are supported:
213
214 - `redis://` creates a TCP socket connection. See more at:
215 <https://www.iana.org/assignments/uri-schemes/prov/redis>
216 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
217 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
218
219 The username, password, hostname, path and all querystring values are passed
220 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
221 with their corresponding characters.
222
223 All querystring options are cast to their appropriate Python types. Boolean
224 arguments can be specified with string values "True"/"False" or "Yes"/"No".
225 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
226 parsed, the querystring arguments and keyword arguments are passed to
227 :class:`~redis.asyncio.connection.Connection` when created.
228 In the case of conflicting arguments, querystring arguments are used.
229 """
230 kwargs.update(parse_url(url))
231 if kwargs.pop("connection_class", None) is SSLConnection:
232 kwargs["ssl"] = True
233 return cls(**kwargs)
234
235 __slots__ = (
236 "_initialize",
237 "_lock",
238 "retry",
239 "command_flags",
240 "commands_parser",
241 "connection_kwargs",
242 "encoder",
243 "node_flags",
244 "nodes_manager",
245 "read_from_replicas",
246 "reinitialize_counter",
247 "reinitialize_steps",
248 "response_callbacks",
249 "result_callbacks",
250 )
251
252 @deprecated_args(
253 args_to_warn=["read_from_replicas"],
254 reason="Please configure the 'load_balancing_strategy' instead",
255 version="5.3.0",
256 )
257 @deprecated_args(
258 args_to_warn=[
259 "cluster_error_retry_attempts",
260 ],
261 reason="Please configure the 'retry' object instead",
262 version="6.0.0",
263 )
264 def __init__(
265 self,
266 host: Optional[str] = None,
267 port: Union[str, int] = 6379,
268 # Cluster related kwargs
269 startup_nodes: Optional[List["ClusterNode"]] = None,
270 require_full_coverage: bool = True,
271 read_from_replicas: bool = False,
272 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
273 dynamic_startup_nodes: bool = True,
274 reinitialize_steps: int = 5,
275 cluster_error_retry_attempts: int = 3,
276 max_connections: int = 2**31,
277 retry: Optional["Retry"] = None,
278 retry_on_error: Optional[List[Type[Exception]]] = None,
279 # Client related kwargs
280 db: Union[str, int] = 0,
281 path: Optional[str] = None,
282 credential_provider: Optional[CredentialProvider] = None,
283 username: Optional[str] = None,
284 password: Optional[str] = None,
285 client_name: Optional[str] = None,
286 lib_name: Optional[str] = "redis-py",
287 lib_version: Optional[str] = get_lib_version(),
288 # Encoding related kwargs
289 encoding: str = "utf-8",
290 encoding_errors: str = "strict",
291 decode_responses: bool = False,
292 # Connection related kwargs
293 health_check_interval: float = 0,
294 socket_connect_timeout: Optional[float] = None,
295 socket_keepalive: bool = False,
296 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
297 socket_timeout: Optional[float] = None,
298 # SSL related kwargs
299 ssl: bool = False,
300 ssl_ca_certs: Optional[str] = None,
301 ssl_ca_data: Optional[str] = None,
302 ssl_cert_reqs: Union[str, VerifyMode] = "required",
303 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
304 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
305 ssl_certfile: Optional[str] = None,
306 ssl_check_hostname: bool = True,
307 ssl_keyfile: Optional[str] = None,
308 ssl_min_version: Optional[TLSVersion] = None,
309 ssl_ciphers: Optional[str] = None,
310 protocol: Optional[int] = 2,
311 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
312 event_dispatcher: Optional[EventDispatcher] = None,
313 ) -> None:
314 if db:
315 raise RedisClusterException(
316 "Argument 'db' must be 0 or None in cluster mode"
317 )
318
319 if path:
320 raise RedisClusterException(
321 "Unix domain socket is not supported in cluster mode"
322 )
323
324 if (not host or not port) and not startup_nodes:
325 raise RedisClusterException(
326 "RedisCluster requires at least one node to discover the cluster.\n"
327 "Please provide one of the following or use RedisCluster.from_url:\n"
328 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
329 " - startup_nodes: RedisCluster(startup_nodes=["
330 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
331 )
332
333 kwargs: Dict[str, Any] = {
334 "max_connections": max_connections,
335 "connection_class": Connection,
336 # Client related kwargs
337 "credential_provider": credential_provider,
338 "username": username,
339 "password": password,
340 "client_name": client_name,
341 "lib_name": lib_name,
342 "lib_version": lib_version,
343 # Encoding related kwargs
344 "encoding": encoding,
345 "encoding_errors": encoding_errors,
346 "decode_responses": decode_responses,
347 # Connection related kwargs
348 "health_check_interval": health_check_interval,
349 "socket_connect_timeout": socket_connect_timeout,
350 "socket_keepalive": socket_keepalive,
351 "socket_keepalive_options": socket_keepalive_options,
352 "socket_timeout": socket_timeout,
353 "protocol": protocol,
354 }
355
356 if ssl:
357 # SSL related kwargs
358 kwargs.update(
359 {
360 "connection_class": SSLConnection,
361 "ssl_ca_certs": ssl_ca_certs,
362 "ssl_ca_data": ssl_ca_data,
363 "ssl_cert_reqs": ssl_cert_reqs,
364 "ssl_include_verify_flags": ssl_include_verify_flags,
365 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
366 "ssl_certfile": ssl_certfile,
367 "ssl_check_hostname": ssl_check_hostname,
368 "ssl_keyfile": ssl_keyfile,
369 "ssl_min_version": ssl_min_version,
370 "ssl_ciphers": ssl_ciphers,
371 }
372 )
373
374 if read_from_replicas or load_balancing_strategy:
375 # Call our on_connect function to configure READONLY mode
376 kwargs["redis_connect_func"] = self.on_connect
377
378 if retry:
379 self.retry = retry
380 else:
381 self.retry = Retry(
382 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
383 retries=cluster_error_retry_attempts,
384 )
385 if retry_on_error:
386 self.retry.update_supported_errors(retry_on_error)
387
388 kwargs["response_callbacks"] = _RedisCallbacks.copy()
389 if kwargs.get("protocol") in ["3", 3]:
390 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
391 else:
392 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
393 self.connection_kwargs = kwargs
394
395 if startup_nodes:
396 passed_nodes = []
397 for node in startup_nodes:
398 passed_nodes.append(
399 ClusterNode(node.host, node.port, **self.connection_kwargs)
400 )
401 startup_nodes = passed_nodes
402 else:
403 startup_nodes = []
404 if host and port:
405 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
406
407 if event_dispatcher is None:
408 self._event_dispatcher = EventDispatcher()
409 else:
410 self._event_dispatcher = event_dispatcher
411
412 self.startup_nodes = startup_nodes
413 self.nodes_manager = NodesManager(
414 startup_nodes,
415 require_full_coverage,
416 kwargs,
417 dynamic_startup_nodes=dynamic_startup_nodes,
418 address_remap=address_remap,
419 event_dispatcher=self._event_dispatcher,
420 )
421 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
422 self.read_from_replicas = read_from_replicas
423 self.load_balancing_strategy = load_balancing_strategy
424 self.reinitialize_steps = reinitialize_steps
425 self.reinitialize_counter = 0
426 self.commands_parser = AsyncCommandsParser()
427 self.node_flags = self.__class__.NODE_FLAGS.copy()
428 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
429 self.response_callbacks = kwargs["response_callbacks"]
430 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
431 self.result_callbacks["CLUSTER SLOTS"] = (
432 lambda cmd, res, **kwargs: parse_cluster_slots(
433 list(res.values())[0], **kwargs
434 )
435 )
436
437 self._initialize = True
438 self._lock: Optional[asyncio.Lock] = None
439
440 # When used as an async context manager, we need to increment and decrement
441 # a usage counter so that we can close the connection pool when no one is
442 # using the client.
443 self._usage_counter = 0
444 self._usage_lock = asyncio.Lock()
445
446 async def initialize(self) -> "RedisCluster":
447 """Get all nodes from startup nodes & creates connections if not initialized."""
448 if self._initialize:
449 if not self._lock:
450 self._lock = asyncio.Lock()
451 async with self._lock:
452 if self._initialize:
453 try:
454 await self.nodes_manager.initialize()
455 await self.commands_parser.initialize(
456 self.nodes_manager.default_node
457 )
458 self._initialize = False
459 except BaseException:
460 await self.nodes_manager.aclose()
461 await self.nodes_manager.aclose("startup_nodes")
462 raise
463 return self
464
465 async def aclose(self) -> None:
466 """Close all connections & client if initialized."""
467 if not self._initialize:
468 if not self._lock:
469 self._lock = asyncio.Lock()
470 async with self._lock:
471 if not self._initialize:
472 self._initialize = True
473 await self.nodes_manager.aclose()
474 await self.nodes_manager.aclose("startup_nodes")
475
476 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
477 async def close(self) -> None:
478 """alias for aclose() for backwards compatibility"""
479 await self.aclose()
480
481 async def __aenter__(self) -> "RedisCluster":
482 """
483 Async context manager entry. Increments a usage counter so that the
484 connection pool is only closed (via aclose()) when no context is using
485 the client.
486 """
487 await self._increment_usage()
488 try:
489 # Initialize the client (i.e. establish connection, etc.)
490 return await self.initialize()
491 except Exception:
492 # If initialization fails, decrement the counter to keep it in sync
493 await self._decrement_usage()
494 raise
495
496 async def _increment_usage(self) -> int:
497 """
498 Helper coroutine to increment the usage counter while holding the lock.
499 Returns the new value of the usage counter.
500 """
501 async with self._usage_lock:
502 self._usage_counter += 1
503 return self._usage_counter
504
505 async def _decrement_usage(self) -> int:
506 """
507 Helper coroutine to decrement the usage counter while holding the lock.
508 Returns the new value of the usage counter.
509 """
510 async with self._usage_lock:
511 self._usage_counter -= 1
512 return self._usage_counter
513
514 async def __aexit__(self, exc_type, exc_value, traceback):
515 """
516 Async context manager exit. Decrements a usage counter. If this is the
517 last exit (counter becomes zero), the client closes its connection pool.
518 """
519 current_usage = await asyncio.shield(self._decrement_usage())
520 if current_usage == 0:
521 # This was the last active context, so disconnect the pool.
522 await asyncio.shield(self.aclose())
523
524 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
525 return self.initialize().__await__()
526
527 _DEL_MESSAGE = "Unclosed RedisCluster client"
528
529 def __del__(
530 self,
531 _warn: Any = warnings.warn,
532 _grl: Any = asyncio.get_running_loop,
533 ) -> None:
534 if hasattr(self, "_initialize") and not self._initialize:
535 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
536 try:
537 context = {"client": self, "message": self._DEL_MESSAGE}
538 _grl().call_exception_handler(context)
539 except RuntimeError:
540 pass
541
542 async def on_connect(self, connection: Connection) -> None:
543 await connection.on_connect()
544
545 # Sending READONLY command to server to configure connection as
546 # readonly. Since each cluster node may change its server type due
547 # to a failover, we should establish a READONLY connection
548 # regardless of the server type. If this is a primary connection,
549 # READONLY would not affect executing write commands.
550 await connection.send_command("READONLY")
551 if str_if_bytes(await connection.read_response()) != "OK":
552 raise ConnectionError("READONLY command failed")
553
554 def get_nodes(self) -> List["ClusterNode"]:
555 """Get all nodes of the cluster."""
556 return list(self.nodes_manager.nodes_cache.values())
557
558 def get_primaries(self) -> List["ClusterNode"]:
559 """Get the primary nodes of the cluster."""
560 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
561
562 def get_replicas(self) -> List["ClusterNode"]:
563 """Get the replica nodes of the cluster."""
564 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
565
566 def get_random_node(self) -> "ClusterNode":
567 """Get a random node of the cluster."""
568 return random.choice(list(self.nodes_manager.nodes_cache.values()))
569
570 def get_default_node(self) -> "ClusterNode":
571 """Get the default node of the client."""
572 return self.nodes_manager.default_node
573
574 def set_default_node(self, node: "ClusterNode") -> None:
575 """
576 Set the default node of the client.
577
578 :raises DataError: if None is passed or node does not exist in cluster.
579 """
580 if not node or not self.get_node(node_name=node.name):
581 raise DataError("The requested node does not exist in the cluster.")
582
583 self.nodes_manager.default_node = node
584
585 def get_node(
586 self,
587 host: Optional[str] = None,
588 port: Optional[int] = None,
589 node_name: Optional[str] = None,
590 ) -> Optional["ClusterNode"]:
591 """Get node by (host, port) or node_name."""
592 return self.nodes_manager.get_node(host, port, node_name)
593
594 def get_node_from_key(
595 self, key: str, replica: bool = False
596 ) -> Optional["ClusterNode"]:
597 """
598 Get the cluster node corresponding to the provided key.
599
600 :param key:
601 :param replica:
602 | Indicates if a replica should be returned
603 |
604 None will returned if no replica holds this key
605
606 :raises SlotNotCoveredError: if the key is not covered by any slot.
607 """
608 slot = self.keyslot(key)
609 slot_cache = self.nodes_manager.slots_cache.get(slot)
610 if not slot_cache:
611 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
612
613 if replica:
614 if len(self.nodes_manager.slots_cache[slot]) < 2:
615 return None
616 node_idx = 1
617 else:
618 node_idx = 0
619
620 return slot_cache[node_idx]
621
622 def keyslot(self, key: EncodableT) -> int:
623 """
624 Find the keyslot for a given key.
625
626 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
627 """
628 return key_slot(self.encoder.encode(key))
629
630 def get_encoder(self) -> Encoder:
631 """Get the encoder object of the client."""
632 return self.encoder
633
634 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
635 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
636 return self.connection_kwargs
637
638 def set_retry(self, retry: Retry) -> None:
639 self.retry = retry
640
641 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
642 """Set a custom response callback."""
643 self.response_callbacks[command] = callback
644
645 async def _determine_nodes(
646 self, command: str, *args: Any, node_flag: Optional[str] = None
647 ) -> List["ClusterNode"]:
648 # Determine which nodes should be executed the command on.
649 # Returns a list of target nodes.
650 if not node_flag:
651 # get the nodes group for this command if it was predefined
652 node_flag = self.command_flags.get(command)
653
654 if node_flag in self.node_flags:
655 if node_flag == self.__class__.DEFAULT_NODE:
656 # return the cluster's default node
657 return [self.nodes_manager.default_node]
658 if node_flag == self.__class__.PRIMARIES:
659 # return all primaries
660 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
661 if node_flag == self.__class__.REPLICAS:
662 # return all replicas
663 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
664 if node_flag == self.__class__.ALL_NODES:
665 # return all nodes
666 return list(self.nodes_manager.nodes_cache.values())
667 if node_flag == self.__class__.RANDOM:
668 # return a random node
669 return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
670
671 # get the node that holds the key's slot
672 return [
673 self.nodes_manager.get_node_from_slot(
674 await self._determine_slot(command, *args),
675 self.read_from_replicas and command in READ_COMMANDS,
676 self.load_balancing_strategy if command in READ_COMMANDS else None,
677 )
678 ]
679
680 async def _determine_slot(self, command: str, *args: Any) -> int:
681 if self.command_flags.get(command) == SLOT_ID:
682 # The command contains the slot ID
683 return int(args[0])
684
685 # Get the keys in the command
686
687 # EVAL and EVALSHA are common enough that it's wasteful to go to the
688 # redis server to parse the keys. Besides, there is a bug in redis<7.0
689 # where `self._get_command_keys()` fails anyway. So, we special case
690 # EVAL/EVALSHA.
691 # - issue: https://github.com/redis/redis/issues/9493
692 # - fix: https://github.com/redis/redis/pull/9733
693 if command.upper() in ("EVAL", "EVALSHA"):
694 # command syntax: EVAL "script body" num_keys ...
695 if len(args) < 2:
696 raise RedisClusterException(
697 f"Invalid args in command: {command, *args}"
698 )
699 keys = args[2 : 2 + int(args[1])]
700 # if there are 0 keys, that means the script can be run on any node
701 # so we can just return a random slot
702 if not keys:
703 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
704 else:
705 keys = await self.commands_parser.get_keys(command, *args)
706 if not keys:
707 # FCALL can call a function with 0 keys, that means the function
708 # can be run on any node so we can just return a random slot
709 if command.upper() in ("FCALL", "FCALL_RO"):
710 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
711 raise RedisClusterException(
712 "No way to dispatch this command to Redis Cluster. "
713 "Missing key.\nYou can execute the command by specifying "
714 f"target nodes.\nCommand: {args}"
715 )
716
717 # single key command
718 if len(keys) == 1:
719 return self.keyslot(keys[0])
720
721 # multi-key command; we need to make sure all keys are mapped to
722 # the same slot
723 slots = {self.keyslot(key) for key in keys}
724 if len(slots) != 1:
725 raise RedisClusterException(
726 f"{command} - all keys must map to the same key slot"
727 )
728
729 return slots.pop()
730
731 def _is_node_flag(self, target_nodes: Any) -> bool:
732 return isinstance(target_nodes, str) and target_nodes in self.node_flags
733
734 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
735 if isinstance(target_nodes, list):
736 nodes = target_nodes
737 elif isinstance(target_nodes, ClusterNode):
738 # Supports passing a single ClusterNode as a variable
739 nodes = [target_nodes]
740 elif isinstance(target_nodes, dict):
741 # Supports dictionaries of the format {node_name: node}.
742 # It enables to execute commands with multi nodes as follows:
743 # rc.cluster_save_config(rc.get_primaries())
744 nodes = list(target_nodes.values())
745 else:
746 raise TypeError(
747 "target_nodes type can be one of the following: "
748 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
749 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
750 f"The passed type is {type(target_nodes)}"
751 )
752 return nodes
753
754 async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
755 """
756 Execute a raw command on the appropriate cluster node or target_nodes.
757
758 It will retry the command as specified by the retries property of
759 the :attr:`retry` & then raise an exception.
760
761 :param args:
762 | Raw command args
763 :param kwargs:
764
765 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
766 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
767 - Rest of the kwargs are passed to the Redis connection
768
769 :raises RedisClusterException: if target_nodes is not provided & the command
770 can't be mapped to a slot
771 """
772 command = args[0]
773 target_nodes = []
774 target_nodes_specified = False
775 retry_attempts = self.retry.get_retries()
776
777 passed_targets = kwargs.pop("target_nodes", None)
778 if passed_targets and not self._is_node_flag(passed_targets):
779 target_nodes = self._parse_target_nodes(passed_targets)
780 target_nodes_specified = True
781 retry_attempts = 0
782
783 # Add one for the first execution
784 execute_attempts = 1 + retry_attempts
785 for _ in range(execute_attempts):
786 if self._initialize:
787 await self.initialize()
788 if (
789 len(target_nodes) == 1
790 and target_nodes[0] == self.get_default_node()
791 ):
792 # Replace the default cluster node
793 self.replace_default_node()
794 try:
795 if not target_nodes_specified:
796 # Determine the nodes to execute the command on
797 target_nodes = await self._determine_nodes(
798 *args, node_flag=passed_targets
799 )
800 if not target_nodes:
801 raise RedisClusterException(
802 f"No targets were found to execute {args} command on"
803 )
804
805 if len(target_nodes) == 1:
806 # Return the processed result
807 ret = await self._execute_command(target_nodes[0], *args, **kwargs)
808 if command in self.result_callbacks:
809 return self.result_callbacks[command](
810 command, {target_nodes[0].name: ret}, **kwargs
811 )
812 return ret
813 else:
814 keys = [node.name for node in target_nodes]
815 values = await asyncio.gather(
816 *(
817 asyncio.create_task(
818 self._execute_command(node, *args, **kwargs)
819 )
820 for node in target_nodes
821 )
822 )
823 if command in self.result_callbacks:
824 return self.result_callbacks[command](
825 command, dict(zip(keys, values)), **kwargs
826 )
827 return dict(zip(keys, values))
828 except Exception as e:
829 if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
830 # The nodes and slots cache were should be reinitialized.
831 # Try again with the new cluster setup.
832 retry_attempts -= 1
833 continue
834 else:
835 # raise the exception
836 raise e
837
838 async def _execute_command(
839 self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
840 ) -> Any:
841 asking = moved = False
842 redirect_addr = None
843 ttl = self.RedisClusterRequestTTL
844
845 while ttl > 0:
846 ttl -= 1
847 try:
848 if asking:
849 target_node = self.get_node(node_name=redirect_addr)
850 await target_node.execute_command("ASKING")
851 asking = False
852 elif moved:
853 # MOVED occurred and the slots cache was updated,
854 # refresh the target node
855 slot = await self._determine_slot(*args)
856 target_node = self.nodes_manager.get_node_from_slot(
857 slot,
858 self.read_from_replicas and args[0] in READ_COMMANDS,
859 self.load_balancing_strategy
860 if args[0] in READ_COMMANDS
861 else None,
862 )
863 moved = False
864
865 return await target_node.execute_command(*args, **kwargs)
866 except BusyLoadingError:
867 raise
868 except MaxConnectionsError:
869 # MaxConnectionsError indicates client-side resource exhaustion
870 # (too many connections in the pool), not a node failure.
871 # Don't treat this as a node failure - just re-raise the error
872 # without reinitializing the cluster.
873 raise
874 except (ConnectionError, TimeoutError):
875 # Connection retries are being handled in the node's
876 # Retry object.
877 # Remove the failed node from the startup nodes before we try
878 # to reinitialize the cluster
879 self.nodes_manager.startup_nodes.pop(target_node.name, None)
880 # Hard force of reinitialize of the node/slots setup
881 # and try again with the new setup
882 await self.aclose()
883 raise
884 except (ClusterDownError, SlotNotCoveredError):
885 # ClusterDownError can occur during a failover and to get
886 # self-healed, we will try to reinitialize the cluster layout
887 # and retry executing the command
888
889 # SlotNotCoveredError can occur when the cluster is not fully
890 # initialized or can be temporary issue.
891 # We will try to reinitialize the cluster topology
892 # and retry executing the command
893
894 await self.aclose()
895 await asyncio.sleep(0.25)
896 raise
897 except MovedError as e:
898 # First, we will try to patch the slots/nodes cache with the
899 # redirected node output and try again. If MovedError exceeds
900 # 'reinitialize_steps' number of times, we will force
901 # reinitializing the tables, and then try again.
902 # 'reinitialize_steps' counter will increase faster when
903 # the same client object is shared between multiple threads. To
904 # reduce the frequency you can set this variable in the
905 # RedisCluster constructor.
906 self.reinitialize_counter += 1
907 if (
908 self.reinitialize_steps
909 and self.reinitialize_counter % self.reinitialize_steps == 0
910 ):
911 await self.aclose()
912 # Reset the counter
913 self.reinitialize_counter = 0
914 else:
915 self.nodes_manager._moved_exception = e
916 moved = True
917 except AskError as e:
918 redirect_addr = get_node_name(host=e.host, port=e.port)
919 asking = True
920 except TryAgainError:
921 if ttl < self.RedisClusterRequestTTL / 2:
922 await asyncio.sleep(0.05)
923
924 raise ClusterError("TTL exhausted.")
925
926 def pipeline(
927 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
928 ) -> "ClusterPipeline":
929 """
930 Create & return a new :class:`~.ClusterPipeline` object.
931
932 Cluster implementation of pipeline does not support transaction or shard_hint.
933
934 :raises RedisClusterException: if transaction or shard_hint are truthy values
935 """
936 if shard_hint:
937 raise RedisClusterException("shard_hint is deprecated in cluster mode")
938
939 return ClusterPipeline(self, transaction)
940
941 def lock(
942 self,
943 name: KeyT,
944 timeout: Optional[float] = None,
945 sleep: float = 0.1,
946 blocking: bool = True,
947 blocking_timeout: Optional[float] = None,
948 lock_class: Optional[Type[Lock]] = None,
949 thread_local: bool = True,
950 raise_on_release_error: bool = True,
951 ) -> Lock:
952 """
953 Return a new Lock object using key ``name`` that mimics
954 the behavior of threading.Lock.
955
956 If specified, ``timeout`` indicates a maximum life for the lock.
957 By default, it will remain locked until release() is called.
958
959 ``sleep`` indicates the amount of time to sleep per loop iteration
960 when the lock is in blocking mode and another client is currently
961 holding the lock.
962
963 ``blocking`` indicates whether calling ``acquire`` should block until
964 the lock has been acquired or to fail immediately, causing ``acquire``
965 to return False and the lock not being acquired. Defaults to True.
966 Note this value can be overridden by passing a ``blocking``
967 argument to ``acquire``.
968
969 ``blocking_timeout`` indicates the maximum amount of time in seconds to
970 spend trying to acquire the lock. A value of ``None`` indicates
971 continue trying forever. ``blocking_timeout`` can be specified as a
972 float or integer, both representing the number of seconds to wait.
973
974 ``lock_class`` forces the specified lock implementation. Note that as
975 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
976 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
977 you have created your own custom lock class.
978
979 ``thread_local`` indicates whether the lock token is placed in
980 thread-local storage. By default, the token is placed in thread local
981 storage so that a thread only sees its token, not a token set by
982 another thread. Consider the following timeline:
983
984 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
985 thread-1 sets the token to "abc"
986 time: 1, thread-2 blocks trying to acquire `my-lock` using the
987 Lock instance.
988 time: 5, thread-1 has not yet completed. redis expires the lock
989 key.
990 time: 5, thread-2 acquired `my-lock` now that it's available.
991 thread-2 sets the token to "xyz"
992 time: 6, thread-1 finishes its work and calls release(). if the
993 token is *not* stored in thread local storage, then
994 thread-1 would see the token value as "xyz" and would be
995 able to successfully release the thread-2's lock.
996
997 ``raise_on_release_error`` indicates whether to raise an exception when
998 the lock is no longer owned when exiting the context manager. By default,
999 this is True, meaning an exception will be raised. If False, the warning
1000 will be logged and the exception will be suppressed.
1001
1002 In some use cases it's necessary to disable thread local storage. For
1003 example, if you have code where one thread acquires a lock and passes
1004 that lock instance to a worker thread to release later. If thread
1005 local storage isn't disabled in this case, the worker thread won't see
1006 the token set by the thread that acquired the lock. Our assumption
1007 is that these cases aren't common and as such default to using
1008 thread local storage."""
1009 if lock_class is None:
1010 lock_class = Lock
1011 return lock_class(
1012 self,
1013 name,
1014 timeout=timeout,
1015 sleep=sleep,
1016 blocking=blocking,
1017 blocking_timeout=blocking_timeout,
1018 thread_local=thread_local,
1019 raise_on_release_error=raise_on_release_error,
1020 )
1021
1022 async def transaction(
1023 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1024 ):
1025 """
1026 Convenience method for executing the callable `func` as a transaction
1027 while watching all keys specified in `watches`. The 'func' callable
1028 should expect a single argument which is a Pipeline object.
1029 """
1030 shard_hint = kwargs.pop("shard_hint", None)
1031 value_from_callable = kwargs.pop("value_from_callable", False)
1032 watch_delay = kwargs.pop("watch_delay", None)
1033 async with self.pipeline(True, shard_hint) as pipe:
1034 while True:
1035 try:
1036 if watches:
1037 await pipe.watch(*watches)
1038 func_value = await func(pipe)
1039 exec_value = await pipe.execute()
1040 return func_value if value_from_callable else exec_value
1041 except WatchError:
1042 if watch_delay is not None and watch_delay > 0:
1043 time.sleep(watch_delay)
1044 continue
1045
1046
1047class ClusterNode:
1048 """
1049 Create a new ClusterNode.
1050
1051 Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
1052 objects for the (host, port).
1053 """
1054
1055 __slots__ = (
1056 "_connections",
1057 "_free",
1058 "_lock",
1059 "_event_dispatcher",
1060 "connection_class",
1061 "connection_kwargs",
1062 "host",
1063 "max_connections",
1064 "name",
1065 "port",
1066 "response_callbacks",
1067 "server_type",
1068 )
1069
1070 def __init__(
1071 self,
1072 host: str,
1073 port: Union[str, int],
1074 server_type: Optional[str] = None,
1075 *,
1076 max_connections: int = 2**31,
1077 connection_class: Type[Connection] = Connection,
1078 **connection_kwargs: Any,
1079 ) -> None:
1080 if host == "localhost":
1081 host = socket.gethostbyname(host)
1082
1083 connection_kwargs["host"] = host
1084 connection_kwargs["port"] = port
1085 self.host = host
1086 self.port = port
1087 self.name = get_node_name(host, port)
1088 self.server_type = server_type
1089
1090 self.max_connections = max_connections
1091 self.connection_class = connection_class
1092 self.connection_kwargs = connection_kwargs
1093 self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
1094
1095 self._connections: List[Connection] = []
1096 self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
1097 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1098 if self._event_dispatcher is None:
1099 self._event_dispatcher = EventDispatcher()
1100
1101 def __repr__(self) -> str:
1102 return (
1103 f"[host={self.host}, port={self.port}, "
1104 f"name={self.name}, server_type={self.server_type}]"
1105 )
1106
1107 def __eq__(self, obj: Any) -> bool:
1108 return isinstance(obj, ClusterNode) and obj.name == self.name
1109
1110 _DEL_MESSAGE = "Unclosed ClusterNode object"
1111
1112 def __del__(
1113 self,
1114 _warn: Any = warnings.warn,
1115 _grl: Any = asyncio.get_running_loop,
1116 ) -> None:
1117 for connection in self._connections:
1118 if connection.is_connected:
1119 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
1120
1121 try:
1122 context = {"client": self, "message": self._DEL_MESSAGE}
1123 _grl().call_exception_handler(context)
1124 except RuntimeError:
1125 pass
1126 break
1127
1128 async def disconnect(self) -> None:
1129 ret = await asyncio.gather(
1130 *(
1131 asyncio.create_task(connection.disconnect())
1132 for connection in self._connections
1133 ),
1134 return_exceptions=True,
1135 )
1136 exc = next((res for res in ret if isinstance(res, Exception)), None)
1137 if exc:
1138 raise exc
1139
1140 def acquire_connection(self) -> Connection:
1141 try:
1142 return self._free.popleft()
1143 except IndexError:
1144 if len(self._connections) < self.max_connections:
1145 # We are configuring the connection pool not to retry
1146 # connections on lower level clients to avoid retrying
1147 # connections to nodes that are not reachable
1148 # and to avoid blocking the connection pool.
1149 # The only error that will have some handling in the lower
1150 # level clients is ConnectionError which will trigger disconnection
1151 # of the socket.
1152 # The retries will be handled on cluster client level
1153 # where we will have proper handling of the cluster topology
1154 retry = Retry(
1155 backoff=NoBackoff(),
1156 retries=0,
1157 supported_errors=(ConnectionError,),
1158 )
1159 connection_kwargs = self.connection_kwargs.copy()
1160 connection_kwargs["retry"] = retry
1161 connection = self.connection_class(**connection_kwargs)
1162 self._connections.append(connection)
1163 return connection
1164
1165 raise MaxConnectionsError()
1166
1167 def release(self, connection: Connection) -> None:
1168 """
1169 Release connection back to free queue.
1170 """
1171 self._free.append(connection)
1172
1173 async def parse_response(
1174 self, connection: Connection, command: str, **kwargs: Any
1175 ) -> Any:
1176 try:
1177 if NEVER_DECODE in kwargs:
1178 response = await connection.read_response(disable_decoding=True)
1179 kwargs.pop(NEVER_DECODE)
1180 else:
1181 response = await connection.read_response()
1182 except ResponseError:
1183 if EMPTY_RESPONSE in kwargs:
1184 return kwargs[EMPTY_RESPONSE]
1185 raise
1186
1187 if EMPTY_RESPONSE in kwargs:
1188 kwargs.pop(EMPTY_RESPONSE)
1189
1190 # Remove keys entry, it needs only for cache.
1191 kwargs.pop("keys", None)
1192
1193 # Return response
1194 if command in self.response_callbacks:
1195 return self.response_callbacks[command](response, **kwargs)
1196
1197 return response
1198
1199 async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
1200 # Acquire connection
1201 connection = self.acquire_connection()
1202
1203 # Execute command
1204 await connection.send_packed_command(connection.pack_command(*args), False)
1205
1206 # Read response
1207 try:
1208 return await self.parse_response(connection, args[0], **kwargs)
1209 finally:
1210 # Release connection
1211 self._free.append(connection)
1212
1213 async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1214 # Acquire connection
1215 connection = self.acquire_connection()
1216
1217 # Execute command
1218 await connection.send_packed_command(
1219 connection.pack_commands(cmd.args for cmd in commands), False
1220 )
1221
1222 # Read responses
1223 ret = False
1224 for cmd in commands:
1225 try:
1226 cmd.result = await self.parse_response(
1227 connection, cmd.args[0], **cmd.kwargs
1228 )
1229 except Exception as e:
1230 cmd.result = e
1231 ret = True
1232
1233 # Release connection
1234 self._free.append(connection)
1235
1236 return ret
1237
1238 async def re_auth_callback(self, token: TokenInterface):
1239 tmp_queue = collections.deque()
1240 while self._free:
1241 conn = self._free.popleft()
1242 await conn.retry.call_with_retry(
1243 lambda: conn.send_command(
1244 "AUTH", token.try_get("oid"), token.get_value()
1245 ),
1246 lambda error: self._mock(error),
1247 )
1248 await conn.retry.call_with_retry(
1249 lambda: conn.read_response(), lambda error: self._mock(error)
1250 )
1251 tmp_queue.append(conn)
1252
1253 while tmp_queue:
1254 conn = tmp_queue.popleft()
1255 self._free.append(conn)
1256
1257 async def _mock(self, error: RedisError):
1258 """
1259 Dummy functions, needs to be passed as error callback to retry object.
1260 :param error:
1261 :return:
1262 """
1263 pass
1264
1265
1266class NodesManager:
1267 __slots__ = (
1268 "_dynamic_startup_nodes",
1269 "_moved_exception",
1270 "_event_dispatcher",
1271 "connection_kwargs",
1272 "default_node",
1273 "nodes_cache",
1274 "read_load_balancer",
1275 "require_full_coverage",
1276 "slots_cache",
1277 "startup_nodes",
1278 "address_remap",
1279 )
1280
1281 def __init__(
1282 self,
1283 startup_nodes: List["ClusterNode"],
1284 require_full_coverage: bool,
1285 connection_kwargs: Dict[str, Any],
1286 dynamic_startup_nodes: bool = True,
1287 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1288 event_dispatcher: Optional[EventDispatcher] = None,
1289 ) -> None:
1290 self.startup_nodes = {node.name: node for node in startup_nodes}
1291 self.require_full_coverage = require_full_coverage
1292 self.connection_kwargs = connection_kwargs
1293 self.address_remap = address_remap
1294
1295 self.default_node: "ClusterNode" = None
1296 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1297 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1298 self.read_load_balancer = LoadBalancer()
1299
1300 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1301 self._moved_exception: MovedError = None
1302 if event_dispatcher is None:
1303 self._event_dispatcher = EventDispatcher()
1304 else:
1305 self._event_dispatcher = event_dispatcher
1306
1307 def get_node(
1308 self,
1309 host: Optional[str] = None,
1310 port: Optional[int] = None,
1311 node_name: Optional[str] = None,
1312 ) -> Optional["ClusterNode"]:
1313 if host and port:
1314 # the user passed host and port
1315 if host == "localhost":
1316 host = socket.gethostbyname(host)
1317 return self.nodes_cache.get(get_node_name(host=host, port=port))
1318 elif node_name:
1319 return self.nodes_cache.get(node_name)
1320 else:
1321 raise DataError(
1322 "get_node requires one of the following: 1. node name 2. host and port"
1323 )
1324
1325 def set_nodes(
1326 self,
1327 old: Dict[str, "ClusterNode"],
1328 new: Dict[str, "ClusterNode"],
1329 remove_old: bool = False,
1330 ) -> None:
1331 if remove_old:
1332 for name in list(old.keys()):
1333 if name not in new:
1334 task = asyncio.create_task(old.pop(name).disconnect()) # noqa
1335
1336 for name, node in new.items():
1337 if name in old:
1338 if old[name] is node:
1339 continue
1340 task = asyncio.create_task(old[name].disconnect()) # noqa
1341 old[name] = node
1342
1343 def update_moved_exception(self, exception):
1344 self._moved_exception = exception
1345
1346 def _update_moved_slots(self) -> None:
1347 e = self._moved_exception
1348 redirected_node = self.get_node(host=e.host, port=e.port)
1349 if redirected_node:
1350 # The node already exists
1351 if redirected_node.server_type != PRIMARY:
1352 # Update the node's server type
1353 redirected_node.server_type = PRIMARY
1354 else:
1355 # This is a new node, we will add it to the nodes cache
1356 redirected_node = ClusterNode(
1357 e.host, e.port, PRIMARY, **self.connection_kwargs
1358 )
1359 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
1360 if redirected_node in self.slots_cache[e.slot_id]:
1361 # The MOVED error resulted from a failover, and the new slot owner
1362 # had previously been a replica.
1363 old_primary = self.slots_cache[e.slot_id][0]
1364 # Update the old primary to be a replica and add it to the end of
1365 # the slot's node list
1366 old_primary.server_type = REPLICA
1367 self.slots_cache[e.slot_id].append(old_primary)
1368 # Remove the old replica, which is now a primary, from the slot's
1369 # node list
1370 self.slots_cache[e.slot_id].remove(redirected_node)
1371 # Override the old primary with the new one
1372 self.slots_cache[e.slot_id][0] = redirected_node
1373 if self.default_node == old_primary:
1374 # Update the default node with the new primary
1375 self.default_node = redirected_node
1376 else:
1377 # The new slot owner is a new server, or a server from a different
1378 # shard. We need to remove all current nodes from the slot's list
1379 # (including replications) and add just the new node.
1380 self.slots_cache[e.slot_id] = [redirected_node]
1381 # Reset moved_exception
1382 self._moved_exception = None
1383
1384 def get_node_from_slot(
1385 self,
1386 slot: int,
1387 read_from_replicas: bool = False,
1388 load_balancing_strategy=None,
1389 ) -> "ClusterNode":
1390 if self._moved_exception:
1391 self._update_moved_slots()
1392
1393 if read_from_replicas is True and load_balancing_strategy is None:
1394 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1395
1396 try:
1397 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1398 # get the server index using the strategy defined in load_balancing_strategy
1399 primary_name = self.slots_cache[slot][0].name
1400 node_idx = self.read_load_balancer.get_server_index(
1401 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1402 )
1403 return self.slots_cache[slot][node_idx]
1404 return self.slots_cache[slot][0]
1405 except (IndexError, TypeError):
1406 raise SlotNotCoveredError(
1407 f'Slot "{slot}" not covered by the cluster. '
1408 f'"require_full_coverage={self.require_full_coverage}"'
1409 )
1410
1411 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1412 return [
1413 node
1414 for node in self.nodes_cache.values()
1415 if node.server_type == server_type
1416 ]
1417
1418 async def initialize(self) -> None:
1419 self.read_load_balancer.reset()
1420 tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
1421 tmp_slots: Dict[int, List["ClusterNode"]] = {}
1422 disagreements = []
1423 startup_nodes_reachable = False
1424 fully_covered = False
1425 exception = None
1426 # Convert to tuple to prevent RuntimeError if self.startup_nodes
1427 # is modified during iteration
1428 for startup_node in tuple(self.startup_nodes.values()):
1429 try:
1430 # Make sure cluster mode is enabled on this node
1431 try:
1432 self._event_dispatcher.dispatch(
1433 AfterAsyncClusterInstantiationEvent(
1434 self.nodes_cache,
1435 self.connection_kwargs.get("credential_provider", None),
1436 )
1437 )
1438 cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
1439 except ResponseError:
1440 raise RedisClusterException(
1441 "Cluster mode is not enabled on this node"
1442 )
1443 startup_nodes_reachable = True
1444 except Exception as e:
1445 # Try the next startup node.
1446 # The exception is saved and raised only if we have no more nodes.
1447 exception = e
1448 continue
1449
1450 # CLUSTER SLOTS command results in the following output:
1451 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
1452 # where each node contains the following list: [IP, port, node_id]
1453 # Therefore, cluster_slots[0][2][0] will be the IP address of the
1454 # primary node of the first slot section.
1455 # If there's only one server in the cluster, its ``host`` is ''
1456 # Fix it to the host in startup_nodes
1457 if (
1458 len(cluster_slots) == 1
1459 and not cluster_slots[0][2][0]
1460 and len(self.startup_nodes) == 1
1461 ):
1462 cluster_slots[0][2][0] = startup_node.host
1463
1464 for slot in cluster_slots:
1465 for i in range(2, len(slot)):
1466 slot[i] = [str_if_bytes(val) for val in slot[i]]
1467 primary_node = slot[2]
1468 host = primary_node[0]
1469 if host == "":
1470 host = startup_node.host
1471 port = int(primary_node[1])
1472 host, port = self.remap_host_port(host, port)
1473
1474 nodes_for_slot = []
1475
1476 target_node = tmp_nodes_cache.get(get_node_name(host, port))
1477 if not target_node:
1478 target_node = ClusterNode(
1479 host, port, PRIMARY, **self.connection_kwargs
1480 )
1481 # add this node to the nodes cache
1482 tmp_nodes_cache[target_node.name] = target_node
1483 nodes_for_slot.append(target_node)
1484
1485 replica_nodes = slot[3:]
1486 for replica_node in replica_nodes:
1487 host = replica_node[0]
1488 port = replica_node[1]
1489 host, port = self.remap_host_port(host, port)
1490
1491 target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
1492 if not target_replica_node:
1493 target_replica_node = ClusterNode(
1494 host, port, REPLICA, **self.connection_kwargs
1495 )
1496 # add this node to the nodes cache
1497 tmp_nodes_cache[target_replica_node.name] = target_replica_node
1498 nodes_for_slot.append(target_replica_node)
1499
1500 for i in range(int(slot[0]), int(slot[1]) + 1):
1501 if i not in tmp_slots:
1502 tmp_slots[i] = nodes_for_slot
1503 else:
1504 # Validate that 2 nodes want to use the same slot cache
1505 # setup
1506 tmp_slot = tmp_slots[i][0]
1507 if tmp_slot.name != target_node.name:
1508 disagreements.append(
1509 f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
1510 )
1511
1512 if len(disagreements) > 5:
1513 raise RedisClusterException(
1514 f"startup_nodes could not agree on a valid "
1515 f"slots cache: {', '.join(disagreements)}"
1516 )
1517
1518 # Validate if all slots are covered or if we should try next startup node
1519 fully_covered = True
1520 for i in range(REDIS_CLUSTER_HASH_SLOTS):
1521 if i not in tmp_slots:
1522 fully_covered = False
1523 break
1524 if fully_covered:
1525 break
1526
1527 if not startup_nodes_reachable:
1528 raise RedisClusterException(
1529 f"Redis Cluster cannot be connected. Please provide at least "
1530 f"one reachable node: {str(exception)}"
1531 ) from exception
1532
1533 # Check if the slots are not fully covered
1534 if not fully_covered and self.require_full_coverage:
1535 # Despite the requirement that the slots be covered, there
1536 # isn't a full coverage
1537 raise RedisClusterException(
1538 f"All slots are not covered after query all startup_nodes. "
1539 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
1540 f"covered..."
1541 )
1542
1543 # Set the tmp variables to the real variables
1544 self.slots_cache = tmp_slots
1545 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1546
1547 if self._dynamic_startup_nodes:
1548 # Populate the startup nodes with all discovered nodes
1549 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1550
1551 # Set the default node
1552 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
1553 # If initialize was called after a MovedError, clear it
1554 self._moved_exception = None
1555
1556 async def aclose(self, attr: str = "nodes_cache") -> None:
1557 self.default_node = None
1558 await asyncio.gather(
1559 *(
1560 asyncio.create_task(node.disconnect())
1561 for node in getattr(self, attr).values()
1562 )
1563 )
1564
1565 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1566 """
1567 Remap the host and port returned from the cluster to a different
1568 internal value. Useful if the client is not connecting directly
1569 to the cluster.
1570 """
1571 if self.address_remap:
1572 return self.address_remap((host, port))
1573 return host, port
1574
1575
1576class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
1577 """
1578 Create a new ClusterPipeline object.
1579
1580 Usage::
1581
1582 result = await (
1583 rc.pipeline()
1584 .set("A", 1)
1585 .get("A")
1586 .hset("K", "F", "V")
1587 .hgetall("K")
1588 .mset_nonatomic({"A": 2, "B": 3})
1589 .get("A")
1590 .get("B")
1591 .delete("A", "B", "K")
1592 .execute()
1593 )
1594 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
1595
1596 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
1597 are split across multiple nodes, you'll get multiple results for them in the array.
1598
1599 Retryable errors:
1600 - :class:`~.ClusterDownError`
1601 - :class:`~.ConnectionError`
1602 - :class:`~.TimeoutError`
1603
1604 Redirection errors:
1605 - :class:`~.TryAgainError`
1606 - :class:`~.MovedError`
1607 - :class:`~.AskError`
1608
1609 :param client:
1610 | Existing :class:`~.RedisCluster` client
1611 """
1612
1613 __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1614
1615 def __init__(
1616 self, client: RedisCluster, transaction: Optional[bool] = None
1617 ) -> None:
1618 self.cluster_client = client
1619 self._transaction = transaction
1620 self._execution_strategy: ExecutionStrategy = (
1621 PipelineStrategy(self)
1622 if not self._transaction
1623 else TransactionStrategy(self)
1624 )
1625
1626 async def initialize(self) -> "ClusterPipeline":
1627 await self._execution_strategy.initialize()
1628 return self
1629
1630 async def __aenter__(self) -> "ClusterPipeline":
1631 return await self.initialize()
1632
1633 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1634 await self.reset()
1635
1636 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1637 return self.initialize().__await__()
1638
1639 def __bool__(self) -> bool:
1640 "Pipeline instances should always evaluate to True on Python 3+"
1641 return True
1642
1643 def __len__(self) -> int:
1644 return len(self._execution_strategy)
1645
1646 def execute_command(
1647 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1648 ) -> "ClusterPipeline":
1649 """
1650 Append a raw command to the pipeline.
1651
1652 :param args:
1653 | Raw command args
1654 :param kwargs:
1655
1656 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
1657 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1658 - Rest of the kwargs are passed to the Redis connection
1659 """
1660 return self._execution_strategy.execute_command(*args, **kwargs)
1661
1662 async def execute(
1663 self, raise_on_error: bool = True, allow_redirections: bool = True
1664 ) -> List[Any]:
1665 """
1666 Execute the pipeline.
1667
1668 It will retry the commands as specified by retries specified in :attr:`retry`
1669 & then raise an exception.
1670
1671 :param raise_on_error:
1672 | Raise the first error if there are any errors
1673 :param allow_redirections:
1674 | Whether to retry each failed command individually in case of redirection
1675 errors
1676
1677 :raises RedisClusterException: if target_nodes is not provided & the command
1678 can't be mapped to a slot
1679 """
1680 try:
1681 return await self._execution_strategy.execute(
1682 raise_on_error, allow_redirections
1683 )
1684 finally:
1685 await self.reset()
1686
1687 def _split_command_across_slots(
1688 self, command: str, *keys: KeyT
1689 ) -> "ClusterPipeline":
1690 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1691 self.execute_command(command, *slot_keys)
1692
1693 return self
1694
1695 async def reset(self):
1696 """
1697 Reset back to empty pipeline.
1698 """
1699 await self._execution_strategy.reset()
1700
1701 def multi(self):
1702 """
1703 Start a transactional block of the pipeline after WATCH commands
1704 are issued. End the transactional block with `execute`.
1705 """
1706 self._execution_strategy.multi()
1707
1708 async def discard(self):
1709 """ """
1710 await self._execution_strategy.discard()
1711
1712 async def watch(self, *names):
1713 """Watches the values at keys ``names``"""
1714 await self._execution_strategy.watch(*names)
1715
1716 async def unwatch(self):
1717 """Unwatches all previously specified keys"""
1718 await self._execution_strategy.unwatch()
1719
1720 async def unlink(self, *names):
1721 await self._execution_strategy.unlink(*names)
1722
1723 def mset_nonatomic(
1724 self, mapping: Mapping[AnyKeyT, EncodableT]
1725 ) -> "ClusterPipeline":
1726 return self._execution_strategy.mset_nonatomic(mapping)
1727
1728
1729for command in PIPELINE_BLOCKED_COMMANDS:
1730 command = command.replace(" ", "_").lower()
1731 if command == "mset_nonatomic":
1732 continue
1733
1734 setattr(ClusterPipeline, command, block_pipeline_command(command))
1735
1736
1737class PipelineCommand:
1738 def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
1739 self.args = args
1740 self.kwargs = kwargs
1741 self.position = position
1742 self.result: Union[Any, Exception] = None
1743
1744 def __repr__(self) -> str:
1745 return f"[{self.position}] {self.args} ({self.kwargs})"
1746
1747
1748class ExecutionStrategy(ABC):
1749 @abstractmethod
1750 async def initialize(self) -> "ClusterPipeline":
1751 """
1752 Initialize the execution strategy.
1753
1754 See ClusterPipeline.initialize()
1755 """
1756 pass
1757
1758 @abstractmethod
1759 def execute_command(
1760 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1761 ) -> "ClusterPipeline":
1762 """
1763 Append a raw command to the pipeline.
1764
1765 See ClusterPipeline.execute_command()
1766 """
1767 pass
1768
1769 @abstractmethod
1770 async def execute(
1771 self, raise_on_error: bool = True, allow_redirections: bool = True
1772 ) -> List[Any]:
1773 """
1774 Execute the pipeline.
1775
1776 It will retry the commands as specified by retries specified in :attr:`retry`
1777 & then raise an exception.
1778
1779 See ClusterPipeline.execute()
1780 """
1781 pass
1782
1783 @abstractmethod
1784 def mset_nonatomic(
1785 self, mapping: Mapping[AnyKeyT, EncodableT]
1786 ) -> "ClusterPipeline":
1787 """
1788 Executes multiple MSET commands according to the provided slot/pairs mapping.
1789
1790 See ClusterPipeline.mset_nonatomic()
1791 """
1792 pass
1793
1794 @abstractmethod
1795 async def reset(self):
1796 """
1797 Resets current execution strategy.
1798
1799 See: ClusterPipeline.reset()
1800 """
1801 pass
1802
1803 @abstractmethod
1804 def multi(self):
1805 """
1806 Starts transactional context.
1807
1808 See: ClusterPipeline.multi()
1809 """
1810 pass
1811
1812 @abstractmethod
1813 async def watch(self, *names):
1814 """
1815 Watch given keys.
1816
1817 See: ClusterPipeline.watch()
1818 """
1819 pass
1820
1821 @abstractmethod
1822 async def unwatch(self):
1823 """
1824 Unwatches all previously specified keys
1825
1826 See: ClusterPipeline.unwatch()
1827 """
1828 pass
1829
1830 @abstractmethod
1831 async def discard(self):
1832 pass
1833
1834 @abstractmethod
1835 async def unlink(self, *names):
1836 """
1837 "Unlink a key specified by ``names``"
1838
1839 See: ClusterPipeline.unlink()
1840 """
1841 pass
1842
1843 @abstractmethod
1844 def __len__(self) -> int:
1845 pass
1846
1847
1848class AbstractStrategy(ExecutionStrategy):
1849 def __init__(self, pipe: ClusterPipeline) -> None:
1850 self._pipe: ClusterPipeline = pipe
1851 self._command_queue: List["PipelineCommand"] = []
1852
1853 async def initialize(self) -> "ClusterPipeline":
1854 if self._pipe.cluster_client._initialize:
1855 await self._pipe.cluster_client.initialize()
1856 self._command_queue = []
1857 return self._pipe
1858
1859 def execute_command(
1860 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1861 ) -> "ClusterPipeline":
1862 self._command_queue.append(
1863 PipelineCommand(len(self._command_queue), *args, **kwargs)
1864 )
1865 return self._pipe
1866
1867 def _annotate_exception(self, exception, number, command):
1868 """
1869 Provides extra context to the exception prior to it being handled
1870 """
1871 cmd = " ".join(map(safe_str, command))
1872 msg = (
1873 f"Command # {number} ({truncate_text(cmd)}) of pipeline "
1874 f"caused error: {exception.args[0]}"
1875 )
1876 exception.args = (msg,) + exception.args[1:]
1877
1878 @abstractmethod
1879 def mset_nonatomic(
1880 self, mapping: Mapping[AnyKeyT, EncodableT]
1881 ) -> "ClusterPipeline":
1882 pass
1883
1884 @abstractmethod
1885 async def execute(
1886 self, raise_on_error: bool = True, allow_redirections: bool = True
1887 ) -> List[Any]:
1888 pass
1889
1890 @abstractmethod
1891 async def reset(self):
1892 pass
1893
1894 @abstractmethod
1895 def multi(self):
1896 pass
1897
1898 @abstractmethod
1899 async def watch(self, *names):
1900 pass
1901
1902 @abstractmethod
1903 async def unwatch(self):
1904 pass
1905
1906 @abstractmethod
1907 async def discard(self):
1908 pass
1909
1910 @abstractmethod
1911 async def unlink(self, *names):
1912 pass
1913
1914 def __len__(self) -> int:
1915 return len(self._command_queue)
1916
1917
1918class PipelineStrategy(AbstractStrategy):
1919 def __init__(self, pipe: ClusterPipeline) -> None:
1920 super().__init__(pipe)
1921
1922 def mset_nonatomic(
1923 self, mapping: Mapping[AnyKeyT, EncodableT]
1924 ) -> "ClusterPipeline":
1925 encoder = self._pipe.cluster_client.encoder
1926
1927 slots_pairs = {}
1928 for pair in mapping.items():
1929 slot = key_slot(encoder.encode(pair[0]))
1930 slots_pairs.setdefault(slot, []).extend(pair)
1931
1932 for pairs in slots_pairs.values():
1933 self.execute_command("MSET", *pairs)
1934
1935 return self._pipe
1936
1937 async def execute(
1938 self, raise_on_error: bool = True, allow_redirections: bool = True
1939 ) -> List[Any]:
1940 if not self._command_queue:
1941 return []
1942
1943 try:
1944 retry_attempts = self._pipe.cluster_client.retry.get_retries()
1945 while True:
1946 try:
1947 if self._pipe.cluster_client._initialize:
1948 await self._pipe.cluster_client.initialize()
1949 return await self._execute(
1950 self._pipe.cluster_client,
1951 self._command_queue,
1952 raise_on_error=raise_on_error,
1953 allow_redirections=allow_redirections,
1954 )
1955
1956 except RedisCluster.ERRORS_ALLOW_RETRY as e:
1957 if retry_attempts > 0:
1958 # Try again with the new cluster setup. All other errors
1959 # should be raised.
1960 retry_attempts -= 1
1961 await self._pipe.cluster_client.aclose()
1962 await asyncio.sleep(0.25)
1963 else:
1964 # All other errors should be raised.
1965 raise e
1966 finally:
1967 await self.reset()
1968
1969 async def _execute(
1970 self,
1971 client: "RedisCluster",
1972 stack: List["PipelineCommand"],
1973 raise_on_error: bool = True,
1974 allow_redirections: bool = True,
1975 ) -> List[Any]:
1976 todo = [
1977 cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
1978 ]
1979
1980 nodes = {}
1981 for cmd in todo:
1982 passed_targets = cmd.kwargs.pop("target_nodes", None)
1983 if passed_targets and not client._is_node_flag(passed_targets):
1984 target_nodes = client._parse_target_nodes(passed_targets)
1985 else:
1986 target_nodes = await client._determine_nodes(
1987 *cmd.args, node_flag=passed_targets
1988 )
1989 if not target_nodes:
1990 raise RedisClusterException(
1991 f"No targets were found to execute {cmd.args} command on"
1992 )
1993 if len(target_nodes) > 1:
1994 raise RedisClusterException(f"Too many targets for command {cmd.args}")
1995 node = target_nodes[0]
1996 if node.name not in nodes:
1997 nodes[node.name] = (node, [])
1998 nodes[node.name][1].append(cmd)
1999
2000 errors = await asyncio.gather(
2001 *(
2002 asyncio.create_task(node[0].execute_pipeline(node[1]))
2003 for node in nodes.values()
2004 )
2005 )
2006
2007 if any(errors):
2008 if allow_redirections:
2009 # send each errored command individually
2010 for cmd in todo:
2011 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
2012 try:
2013 cmd.result = await client.execute_command(
2014 *cmd.args, **cmd.kwargs
2015 )
2016 except Exception as e:
2017 cmd.result = e
2018
2019 if raise_on_error:
2020 for cmd in todo:
2021 result = cmd.result
2022 if isinstance(result, Exception):
2023 command = " ".join(map(safe_str, cmd.args))
2024 msg = (
2025 f"Command # {cmd.position + 1} "
2026 f"({truncate_text(command)}) "
2027 f"of pipeline caused error: {result.args}"
2028 )
2029 result.args = (msg,) + result.args[1:]
2030 raise result
2031
2032 default_cluster_node = client.get_default_node()
2033
2034 # Check whether the default node was used. In some cases,
2035 # 'client.get_default_node()' may return None. The check below
2036 # prevents a potential AttributeError.
2037 if default_cluster_node is not None:
2038 default_node = nodes.get(default_cluster_node.name)
2039 if default_node is not None:
2040 # This pipeline execution used the default node, check if we need
2041 # to replace it.
2042 # Note: when the error is raised we'll reset the default node in the
2043 # caller function.
2044 for cmd in default_node[1]:
2045 # Check if it has a command that failed with a relevant
2046 # exception
2047 if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
2048 client.replace_default_node()
2049 break
2050
2051 return [cmd.result for cmd in stack]
2052
2053 async def reset(self):
2054 """
2055 Reset back to empty pipeline.
2056 """
2057 self._command_queue = []
2058
2059 def multi(self):
2060 raise RedisClusterException(
2061 "method multi() is not supported outside of transactional context"
2062 )
2063
2064 async def watch(self, *names):
2065 raise RedisClusterException(
2066 "method watch() is not supported outside of transactional context"
2067 )
2068
2069 async def unwatch(self):
2070 raise RedisClusterException(
2071 "method unwatch() is not supported outside of transactional context"
2072 )
2073
2074 async def discard(self):
2075 raise RedisClusterException(
2076 "method discard() is not supported outside of transactional context"
2077 )
2078
2079 async def unlink(self, *names):
2080 if len(names) != 1:
2081 raise RedisClusterException(
2082 "unlinking multiple keys is not implemented in pipeline command"
2083 )
2084
2085 return self.execute_command("UNLINK", names[0])
2086
2087
2088class TransactionStrategy(AbstractStrategy):
2089 NO_SLOTS_COMMANDS = {"UNWATCH"}
2090 IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
2091 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
2092 SLOT_REDIRECT_ERRORS = (AskError, MovedError)
2093 CONNECTION_ERRORS = (
2094 ConnectionError,
2095 OSError,
2096 ClusterDownError,
2097 SlotNotCoveredError,
2098 )
2099
2100 def __init__(self, pipe: ClusterPipeline) -> None:
2101 super().__init__(pipe)
2102 self._explicit_transaction = False
2103 self._watching = False
2104 self._pipeline_slots: Set[int] = set()
2105 self._transaction_node: Optional[ClusterNode] = None
2106 self._transaction_connection: Optional[Connection] = None
2107 self._executing = False
2108 self._retry = copy(self._pipe.cluster_client.retry)
2109 self._retry.update_supported_errors(
2110 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2111 )
2112
2113 def _get_client_and_connection_for_transaction(
2114 self,
2115 ) -> Tuple[ClusterNode, Connection]:
2116 """
2117 Find a connection for a pipeline transaction.
2118
2119 For running an atomic transaction, watch keys ensure that contents have not been
2120 altered as long as the watch commands for those keys were sent over the same
2121 connection. So once we start watching a key, we fetch a connection to the
2122 node that owns that slot and reuse it.
2123 """
2124 if not self._pipeline_slots:
2125 raise RedisClusterException(
2126 "At least a command with a key is needed to identify a node"
2127 )
2128
2129 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2130 list(self._pipeline_slots)[0], False
2131 )
2132 self._transaction_node = node
2133
2134 if not self._transaction_connection:
2135 connection: Connection = self._transaction_node.acquire_connection()
2136 self._transaction_connection = connection
2137
2138 return self._transaction_node, self._transaction_connection
2139
2140 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2141 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2142 response = None
2143 error = None
2144
2145 def runner():
2146 nonlocal response
2147 nonlocal error
2148 try:
2149 response = asyncio.run(self._execute_command(*args, **kwargs))
2150 except Exception as e:
2151 error = e
2152
2153 thread = threading.Thread(target=runner)
2154 thread.start()
2155 thread.join()
2156
2157 if error:
2158 raise error
2159
2160 return response
2161
2162 async def _execute_command(
2163 self, *args: Union[KeyT, EncodableT], **kwargs: Any
2164 ) -> Any:
2165 if self._pipe.cluster_client._initialize:
2166 await self._pipe.cluster_client.initialize()
2167
2168 slot_number: Optional[int] = None
2169 if args[0] not in self.NO_SLOTS_COMMANDS:
2170 slot_number = await self._pipe.cluster_client._determine_slot(*args)
2171
2172 if (
2173 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2174 ) and not self._explicit_transaction:
2175 if args[0] == "WATCH":
2176 self._validate_watch()
2177
2178 if slot_number is not None:
2179 if self._pipeline_slots and slot_number not in self._pipeline_slots:
2180 raise CrossSlotTransactionError(
2181 "Cannot watch or send commands on different slots"
2182 )
2183
2184 self._pipeline_slots.add(slot_number)
2185 elif args[0] not in self.NO_SLOTS_COMMANDS:
2186 raise RedisClusterException(
2187 f"Cannot identify slot number for command: {args[0]},"
2188 "it cannot be triggered in a transaction"
2189 )
2190
2191 return self._immediate_execute_command(*args, **kwargs)
2192 else:
2193 if slot_number is not None:
2194 self._pipeline_slots.add(slot_number)
2195
2196 return super().execute_command(*args, **kwargs)
2197
2198 def _validate_watch(self):
2199 if self._explicit_transaction:
2200 raise RedisError("Cannot issue a WATCH after a MULTI")
2201
2202 self._watching = True
2203
2204 async def _immediate_execute_command(self, *args, **options):
2205 return await self._retry.call_with_retry(
2206 lambda: self._get_connection_and_send_command(*args, **options),
2207 self._reinitialize_on_error,
2208 )
2209
2210 async def _get_connection_and_send_command(self, *args, **options):
2211 redis_node, connection = self._get_client_and_connection_for_transaction()
2212 return await self._send_command_parse_response(
2213 connection, redis_node, args[0], *args, **options
2214 )
2215
2216 async def _send_command_parse_response(
2217 self,
2218 connection: Connection,
2219 redis_node: ClusterNode,
2220 command_name,
2221 *args,
2222 **options,
2223 ):
2224 """
2225 Send a command and parse the response
2226 """
2227
2228 await connection.send_command(*args)
2229 output = await redis_node.parse_response(connection, command_name, **options)
2230
2231 if command_name in self.UNWATCH_COMMANDS:
2232 self._watching = False
2233 return output
2234
2235 async def _reinitialize_on_error(self, error):
2236 if self._watching:
2237 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2238 raise WatchError("Slot rebalancing occurred while watching keys")
2239
2240 if (
2241 type(error) in self.SLOT_REDIRECT_ERRORS
2242 or type(error) in self.CONNECTION_ERRORS
2243 ):
2244 if self._transaction_connection:
2245 self._transaction_connection = None
2246
2247 self._pipe.cluster_client.reinitialize_counter += 1
2248 if (
2249 self._pipe.cluster_client.reinitialize_steps
2250 and self._pipe.cluster_client.reinitialize_counter
2251 % self._pipe.cluster_client.reinitialize_steps
2252 == 0
2253 ):
2254 await self._pipe.cluster_client.nodes_manager.initialize()
2255 self.reinitialize_counter = 0
2256 else:
2257 if isinstance(error, AskError):
2258 self._pipe.cluster_client.nodes_manager.update_moved_exception(
2259 error
2260 )
2261
2262 self._executing = False
2263
2264 def _raise_first_error(self, responses, stack):
2265 """
2266 Raise the first exception on the stack
2267 """
2268 for r, cmd in zip(responses, stack):
2269 if isinstance(r, Exception):
2270 self._annotate_exception(r, cmd.position + 1, cmd.args)
2271 raise r
2272
2273 def mset_nonatomic(
2274 self, mapping: Mapping[AnyKeyT, EncodableT]
2275 ) -> "ClusterPipeline":
2276 raise NotImplementedError("Method is not supported in transactional context.")
2277
2278 async def execute(
2279 self, raise_on_error: bool = True, allow_redirections: bool = True
2280 ) -> List[Any]:
2281 stack = self._command_queue
2282 if not stack and (not self._watching or not self._pipeline_slots):
2283 return []
2284
2285 return await self._execute_transaction_with_retries(stack, raise_on_error)
2286
2287 async def _execute_transaction_with_retries(
2288 self, stack: List["PipelineCommand"], raise_on_error: bool
2289 ):
2290 return await self._retry.call_with_retry(
2291 lambda: self._execute_transaction(stack, raise_on_error),
2292 self._reinitialize_on_error,
2293 )
2294
2295 async def _execute_transaction(
2296 self, stack: List["PipelineCommand"], raise_on_error: bool
2297 ):
2298 if len(self._pipeline_slots) > 1:
2299 raise CrossSlotTransactionError(
2300 "All keys involved in a cluster transaction must map to the same slot"
2301 )
2302
2303 self._executing = True
2304
2305 redis_node, connection = self._get_client_and_connection_for_transaction()
2306
2307 stack = chain(
2308 [PipelineCommand(0, "MULTI")],
2309 stack,
2310 [PipelineCommand(0, "EXEC")],
2311 )
2312 commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
2313 packed_commands = connection.pack_commands(commands)
2314 await connection.send_packed_command(packed_commands)
2315 errors = []
2316
2317 # parse off the response for MULTI
2318 # NOTE: we need to handle ResponseErrors here and continue
2319 # so that we read all the additional command messages from
2320 # the socket
2321 try:
2322 await redis_node.parse_response(connection, "MULTI")
2323 except ResponseError as e:
2324 self._annotate_exception(e, 0, "MULTI")
2325 errors.append(e)
2326 except self.CONNECTION_ERRORS as cluster_error:
2327 self._annotate_exception(cluster_error, 0, "MULTI")
2328 raise
2329
2330 # and all the other commands
2331 for i, command in enumerate(self._command_queue):
2332 if EMPTY_RESPONSE in command.kwargs:
2333 errors.append((i, command.kwargs[EMPTY_RESPONSE]))
2334 else:
2335 try:
2336 _ = await redis_node.parse_response(connection, "_")
2337 except self.SLOT_REDIRECT_ERRORS as slot_error:
2338 self._annotate_exception(slot_error, i + 1, command.args)
2339 errors.append(slot_error)
2340 except self.CONNECTION_ERRORS as cluster_error:
2341 self._annotate_exception(cluster_error, i + 1, command.args)
2342 raise
2343 except ResponseError as e:
2344 self._annotate_exception(e, i + 1, command.args)
2345 errors.append(e)
2346
2347 response = None
2348 # parse the EXEC.
2349 try:
2350 response = await redis_node.parse_response(connection, "EXEC")
2351 except ExecAbortError:
2352 if errors:
2353 raise errors[0]
2354 raise
2355
2356 self._executing = False
2357
2358 # EXEC clears any watched keys
2359 self._watching = False
2360
2361 if response is None:
2362 raise WatchError("Watched variable changed.")
2363
2364 # put any parse errors into the response
2365 for i, e in errors:
2366 response.insert(i, e)
2367
2368 if len(response) != len(self._command_queue):
2369 raise InvalidPipelineStack(
2370 "Unexpected response length for cluster pipeline EXEC."
2371 " Command stack was {} but response had length {}".format(
2372 [c.args[0] for c in self._command_queue], len(response)
2373 )
2374 )
2375
2376 # find any errors in the response and raise if necessary
2377 if raise_on_error or len(errors) > 0:
2378 self._raise_first_error(
2379 response,
2380 self._command_queue,
2381 )
2382
2383 # We have to run response callbacks manually
2384 data = []
2385 for r, cmd in zip(response, self._command_queue):
2386 if not isinstance(r, Exception):
2387 command_name = cmd.args[0]
2388 if command_name in self._pipe.cluster_client.response_callbacks:
2389 r = self._pipe.cluster_client.response_callbacks[command_name](
2390 r, **cmd.kwargs
2391 )
2392 data.append(r)
2393 return data
2394
2395 async def reset(self):
2396 self._command_queue = []
2397
2398 # make sure to reset the connection state in the event that we were
2399 # watching something
2400 if self._transaction_connection:
2401 try:
2402 if self._watching:
2403 # call this manually since our unwatch or
2404 # immediate_execute_command methods can call reset()
2405 await self._transaction_connection.send_command("UNWATCH")
2406 await self._transaction_connection.read_response()
2407 # we can safely return the connection to the pool here since we're
2408 # sure we're no longer WATCHing anything
2409 self._transaction_node.release(self._transaction_connection)
2410 self._transaction_connection = None
2411 except self.CONNECTION_ERRORS:
2412 # disconnect will also remove any previous WATCHes
2413 if self._transaction_connection:
2414 await self._transaction_connection.disconnect()
2415
2416 # clean up the other instance attributes
2417 self._transaction_node = None
2418 self._watching = False
2419 self._explicit_transaction = False
2420 self._pipeline_slots = set()
2421 self._executing = False
2422
2423 def multi(self):
2424 if self._explicit_transaction:
2425 raise RedisError("Cannot issue nested calls to MULTI")
2426 if self._command_queue:
2427 raise RedisError(
2428 "Commands without an initial WATCH have already been issued"
2429 )
2430 self._explicit_transaction = True
2431
2432 async def watch(self, *names):
2433 if self._explicit_transaction:
2434 raise RedisError("Cannot issue a WATCH after a MULTI")
2435
2436 return await self.execute_command("WATCH", *names)
2437
2438 async def unwatch(self):
2439 if self._watching:
2440 return await self.execute_command("UNWATCH")
2441
2442 return True
2443
2444 async def discard(self):
2445 await self.reset()
2446
2447 async def unlink(self, *names):
2448 return self.execute_command("UNLINK", *names)