1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34)
35from redis.asyncio.client import ResponseCallbackT
36from redis.asyncio.connection import Connection, SSLConnection, parse_url
37from redis.asyncio.lock import Lock
38from redis.asyncio.observability.recorder import (
39 record_error_count,
40 record_operation_duration,
41)
42from redis.asyncio.retry import Retry
43from redis.auth.token import TokenInterface
44from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
45from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
46from redis.cluster import (
47 PIPELINE_BLOCKED_COMMANDS,
48 PRIMARY,
49 REPLICA,
50 SLOT_ID,
51 AbstractRedisCluster,
52 LoadBalancer,
53 LoadBalancingStrategy,
54 block_pipeline_command,
55 get_node_name,
56 parse_cluster_slots,
57)
58from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
59from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
60from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
61from redis.credentials import CredentialProvider
62from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
63from redis.exceptions import (
64 AskError,
65 BusyLoadingError,
66 ClusterDownError,
67 ClusterError,
68 ConnectionError,
69 CrossSlotTransactionError,
70 DataError,
71 ExecAbortError,
72 InvalidPipelineStack,
73 MaxConnectionsError,
74 MovedError,
75 RedisClusterException,
76 RedisError,
77 ResponseError,
78 SlotNotCoveredError,
79 TimeoutError,
80 TryAgainError,
81 WatchError,
82)
83from redis.typing import AnyKeyT, EncodableT, KeyT
84from redis.utils import (
85 SSL_AVAILABLE,
86 deprecated_args,
87 deprecated_function,
88 get_lib_version,
89 safe_str,
90 str_if_bytes,
91 truncate_text,
92)
93
94if SSL_AVAILABLE:
95 from ssl import TLSVersion, VerifyFlags, VerifyMode
96else:
97 TLSVersion = None
98 VerifyMode = None
99 VerifyFlags = None
100
# Type variable covering every accepted form of the ``target_nodes`` kwarg:
# a node-flag string, a single ClusterNode, a list of ClusterNodes, or a
# mapping of arbitrary keys to ClusterNodes.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
104
105
106class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
107 """
108 Create a new RedisCluster client.
109
110 Pass one of parameters:
111
112 - `host` & `port`
113 - `startup_nodes`
114
115 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
116 | Use ``await`` :meth:`close` to disconnect connections & close client.
117
118 Many commands support the target_nodes kwarg. It can be one of the
119 :attr:`NODE_FLAGS`:
120
121 - :attr:`PRIMARIES`
122 - :attr:`REPLICAS`
123 - :attr:`ALL_NODES`
124 - :attr:`RANDOM`
125 - :attr:`DEFAULT_NODE`
126
127 Note: This client is not thread/process/fork safe.
128
129 :param host:
130 | Can be used to point to a startup node
131 :param port:
132 | Port used if **host** is provided
133 :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
135 :param require_full_coverage:
136 | When set to ``False``: the client will not require a full coverage of
137 the slots. However, if not all slots are covered, and at least one node
138 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
139 a :class:`~.ClusterDownError` for some key-based commands.
140 | When set to ``True``: all slots must be covered to construct the cluster
141 client. If not all slots are covered, :class:`~.RedisClusterException` will be
142 thrown.
143 | See:
144 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
145 :param read_from_replicas:
146 | @deprecated - please use load_balancing_strategy instead
147 | Enable read from replicas in READONLY mode.
148 When set to true, read commands will be assigned between the primary and
149 its replications in a Round-Robin manner.
150 The data read from replicas is eventually consistent with the data in primary nodes.
151 :param load_balancing_strategy:
152 | Enable read from replicas in READONLY mode and defines the load balancing
153 strategy that will be used for cluster node selection.
154 The data read from replicas is eventually consistent with the data in primary nodes.
155 :param dynamic_startup_nodes:
156 | Set the RedisCluster's startup nodes to all the discovered nodes.
157 If true (default value), the cluster's discovered nodes will be used to
158 determine the cluster nodes-slots mapping in the next topology refresh.
159 It will remove the initial passed startup nodes if their endpoints aren't
160 listed in the CLUSTER SLOTS output.
161 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
162 specific IP addresses, it is best to set it to false.
163 :param reinitialize_steps:
164 | Specifies the number of MOVED errors that need to occur before reinitializing
165 the whole cluster topology. If a MOVED error occurs and the cluster does not
166 need to be reinitialized on this current error handling, only the MOVED slot
167 will be patched with the redirected node.
168 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
169 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
170 0.
171 :param cluster_error_retry_attempts:
172 | @deprecated - Please configure the 'retry' object instead
173 In case 'retry' object is set - this argument is ignored!
174
175 Number of times to retry before raising an error when :class:`~.TimeoutError`,
176 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
177 or :class:`~.ClusterDownError` are encountered
178 :param retry:
179 | A retry object that defines the retry strategy and the number of
180 retries for the cluster client.
        In current implementation for the cluster client (starting from redis-py version 6.0.0)
182 the retry object is not yet fully utilized, instead it is used just to determine
183 the number of retries for the cluster client.
184 In the future releases the retry object will be used to handle the cluster client retries!
185 :param max_connections:
186 | Maximum number of connections per node. If there are no free connections & the
187 maximum number of connections are already created, a
188 :class:`~.MaxConnectionsError` is raised.
189 :param address_remap:
190 | An optional callable which, when provided with an internal network
191 address of a node, e.g. a `(host, port)` tuple, will return the address
192 where the node is reachable. This can be used to map the addresses at
193 which the nodes _think_ they are, to addresses at which a client may
194 reach them, such as when they sit behind a proxy.
195
196 | Rest of the arguments will be passed to the
197 :class:`~redis.asyncio.connection.Connection` instances when created
198
199 :raises RedisClusterException:
200 if any arguments are invalid or unknown. Eg:
201
202 - `db` != 0 or None
203 - `path` argument for unix socket connection
204 - none of the `host`/`port` & `startup_nodes` were provided
205
206 """
207
208 @classmethod
209 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
210 """
211 Return a Redis client object configured from the given URL.
212
213 For example::
214
215 redis://[[username]:[password]]@localhost:6379/0
216 rediss://[[username]:[password]]@localhost:6379/0
217
218 Three URL schemes are supported:
219
220 - `redis://` creates a TCP socket connection. See more at:
221 <https://www.iana.org/assignments/uri-schemes/prov/redis>
222 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
223 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
224
225 The username, password, hostname, path and all querystring values are passed
226 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
227 with their corresponding characters.
228
229 All querystring options are cast to their appropriate Python types. Boolean
230 arguments can be specified with string values "True"/"False" or "Yes"/"No".
231 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
232 parsed, the querystring arguments and keyword arguments are passed to
233 :class:`~redis.asyncio.connection.Connection` when created.
234 In the case of conflicting arguments, querystring arguments are used.
235 """
236 kwargs.update(parse_url(url))
237 if kwargs.pop("connection_class", None) is SSLConnection:
238 kwargs["ssl"] = True
239 return cls(**kwargs)
240
    # Core attributes assigned by __init__.
    # NOTE(review): __init__ also assigns names not listed here (e.g.
    # startup_nodes, load_balancing_strategy, _usage_counter, _usage_lock,
    # _event_dispatcher) — presumably a base class still provides a __dict__;
    # confirm before relying on __slots__ for memory savings.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
257
    @deprecated_args(
        args_to_warn=["read_from_replicas"],
        reason="Please configure the 'load_balancing_strategy' instead",
        version="5.3.0",
    )
    @deprecated_args(
        args_to_warn=[
            "cluster_error_retry_attempts",
        ],
        reason="Please configure the 'retry' object instead",
        version="6.0.0",
    )
    def __init__(
        self,
        host: Optional[str] = None,
        port: Union[str, int] = 6379,
        # Cluster related kwargs
        startup_nodes: Optional[List["ClusterNode"]] = None,
        require_full_coverage: bool = True,
        read_from_replicas: bool = False,
        load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
        dynamic_startup_nodes: bool = True,
        reinitialize_steps: int = 5,
        cluster_error_retry_attempts: int = 3,
        max_connections: int = 2**31,
        retry: Optional["Retry"] = None,
        retry_on_error: Optional[List[Type[Exception]]] = None,
        # Client related kwargs
        db: Union[str, int] = 0,
        path: Optional[str] = None,
        credential_provider: Optional[CredentialProvider] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        # Encoding related kwargs
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        # Connection related kwargs
        health_check_interval: float = 0,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_timeout: Optional[float] = None,
        # SSL related kwargs
        ssl: bool = False,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_certfile: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_keyfile: Optional[str] = None,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        protocol: Optional[int] = 2,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        # NOTE(review): default is a single shared instance created at class
        # definition time — TODO confirm AsyncStaticPolicyResolver is stateless.
        policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
    ) -> None:
        """
        Validate arguments and build connection kwargs, retry strategy,
        policy mappings and the :class:`NodesManager`. See the class
        docstring for parameter documentation. No network I/O happens here;
        connections are created lazily by :meth:`initialize`.
        """
        # Cluster mode only supports database 0 over TCP sockets.
        if db:
            raise RedisClusterException(
                "Argument 'db' must be 0 or None in cluster mode"
            )

        if path:
            raise RedisClusterException(
                "Unix domain socket is not supported in cluster mode"
            )

        if (not host or not port) and not startup_nodes:
            raise RedisClusterException(
                "RedisCluster requires at least one node to discover the cluster.\n"
                "Please provide one of the following or use RedisCluster.from_url:\n"
                ' - host and port: RedisCluster(host="localhost", port=6379)\n'
                " - startup_nodes: RedisCluster(startup_nodes=["
                'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
            )

        # Kwargs forwarded to every per-node Connection created later.
        kwargs: Dict[str, Any] = {
            "max_connections": max_connections,
            "connection_class": Connection,
            # Client related kwargs
            "credential_provider": credential_provider,
            "username": username,
            "password": password,
            "client_name": client_name,
            "lib_name": lib_name,
            "lib_version": lib_version,
            # Encoding related kwargs
            "encoding": encoding,
            "encoding_errors": encoding_errors,
            "decode_responses": decode_responses,
            # Connection related kwargs
            "health_check_interval": health_check_interval,
            "socket_connect_timeout": socket_connect_timeout,
            "socket_keepalive": socket_keepalive,
            "socket_keepalive_options": socket_keepalive_options,
            "socket_timeout": socket_timeout,
            "protocol": protocol,
        }

        if ssl:
            # SSL related kwargs
            kwargs.update(
                {
                    "connection_class": SSLConnection,
                    "ssl_ca_certs": ssl_ca_certs,
                    "ssl_ca_data": ssl_ca_data,
                    "ssl_cert_reqs": ssl_cert_reqs,
                    "ssl_include_verify_flags": ssl_include_verify_flags,
                    "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                    "ssl_certfile": ssl_certfile,
                    "ssl_check_hostname": ssl_check_hostname,
                    "ssl_keyfile": ssl_keyfile,
                    "ssl_min_version": ssl_min_version,
                    "ssl_ciphers": ssl_ciphers,
                }
            )

        if read_from_replicas or load_balancing_strategy:
            # Call our on_connect function to configure READONLY mode
            kwargs["redis_connect_func"] = self.on_connect

        # An explicit retry object wins; otherwise build one from the
        # (deprecated) cluster_error_retry_attempts count.
        if retry:
            self.retry = retry
        else:
            self.retry = Retry(
                backoff=ExponentialWithJitterBackoff(base=1, cap=10),
                retries=cluster_error_retry_attempts,
            )
        if retry_on_error:
            self.retry.update_supported_errors(retry_on_error)

        # Response callbacks depend on the RESP protocol version.
        kwargs["response_callbacks"] = _RedisCallbacks.copy()
        if kwargs.get("protocol") in ["3", 3]:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
        else:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
        self.connection_kwargs = kwargs

        # Rebuild startup nodes with the final connection kwargs so all
        # ClusterNode instances share the same configuration.
        if startup_nodes:
            passed_nodes = []
            for node in startup_nodes:
                passed_nodes.append(
                    ClusterNode(node.host, node.port, **self.connection_kwargs)
                )
            startup_nodes = passed_nodes
        else:
            startup_nodes = []
        if host and port:
            startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

        self.startup_nodes = startup_nodes
        self.nodes_manager = NodesManager(
            startup_nodes,
            require_full_coverage,
            kwargs,
            dynamic_startup_nodes=dynamic_startup_nodes,
            address_remap=address_remap,
            event_dispatcher=self._event_dispatcher,
        )
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.read_from_replicas = read_from_replicas
        self.load_balancing_strategy = load_balancing_strategy
        self.reinitialize_steps = reinitialize_steps
        self.reinitialize_counter = 0

        # For backward compatibility, mapping from existing policies to new one
        self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
            self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
            self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
            self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
            self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
            self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
            SLOT_ID: RequestPolicy.DEFAULT_KEYED,
        }

        # Maps each request/response policy to the callable that produces the
        # target nodes (request policies) or post-processes the result
        # (response policies). Consumed by _determine_nodes/execute_command.
        self._policies_callback_mapping: dict[
            Union[RequestPolicy, ResponsePolicy], Callable
        ] = {
            RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
                self.get_random_primary_or_all_nodes(command_name)
            ],
            RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
            RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
            RequestPolicy.ALL_SHARDS: self.get_primaries,
            RequestPolicy.ALL_NODES: self.get_nodes,
            RequestPolicy.ALL_REPLICAS: self.get_replicas,
            RequestPolicy.SPECIAL: self.get_special_nodes,
            ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
            ResponsePolicy.DEFAULT_KEYED: lambda res: res,
        }

        self._policy_resolver = policy_resolver
        self.commands_parser = AsyncCommandsParser()
        # Nodes last targeted by FT.AGGREGATE; used to route FT.CURSOR calls.
        self._aggregate_nodes = None
        self.node_flags = self.__class__.NODE_FLAGS.copy()
        self.command_flags = self.__class__.COMMAND_FLAGS.copy()
        self.response_callbacks = kwargs["response_callbacks"]
        self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
        # CLUSTER SLOTS is executed on a single node; unwrap the
        # {node_name: response} dict before parsing.
        self.result_callbacks["CLUSTER SLOTS"] = (
            lambda cmd, res, **kwargs: parse_cluster_slots(
                list(res.values())[0], **kwargs
            )
        )

        # True until initialize() completes; aclose() resets it to True.
        self._initialize = True
        # Created lazily inside initialize()/aclose() so that the client can
        # be constructed outside a running event loop.
        self._lock: Optional[asyncio.Lock] = None

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()
481
482 async def initialize(self) -> "RedisCluster":
483 """Get all nodes from startup nodes & creates connections if not initialized."""
484 if self._initialize:
485 if not self._lock:
486 self._lock = asyncio.Lock()
487 async with self._lock:
488 if self._initialize:
489 try:
490 await self.nodes_manager.initialize()
491 await self.commands_parser.initialize(
492 self.nodes_manager.default_node
493 )
494 self._initialize = False
495 except BaseException:
496 await self.nodes_manager.aclose()
497 await self.nodes_manager.aclose("startup_nodes")
498 raise
499 return self
500
501 async def aclose(self) -> None:
502 """Close all connections & client if initialized."""
503 if not self._initialize:
504 if not self._lock:
505 self._lock = asyncio.Lock()
506 async with self._lock:
507 if not self._initialize:
508 self._initialize = True
509 await self.nodes_manager.aclose()
510 await self.nodes_manager.aclose("startup_nodes")
511
512 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
513 async def close(self) -> None:
514 """alias for aclose() for backwards compatibility"""
515 await self.aclose()
516
517 async def __aenter__(self) -> "RedisCluster":
518 """
519 Async context manager entry. Increments a usage counter so that the
520 connection pool is only closed (via aclose()) when no context is using
521 the client.
522 """
523 await self._increment_usage()
524 try:
525 # Initialize the client (i.e. establish connection, etc.)
526 return await self.initialize()
527 except Exception:
528 # If initialization fails, decrement the counter to keep it in sync
529 await self._decrement_usage()
530 raise
531
532 async def _increment_usage(self) -> int:
533 """
534 Helper coroutine to increment the usage counter while holding the lock.
535 Returns the new value of the usage counter.
536 """
537 async with self._usage_lock:
538 self._usage_counter += 1
539 return self._usage_counter
540
541 async def _decrement_usage(self) -> int:
542 """
543 Helper coroutine to decrement the usage counter while holding the lock.
544 Returns the new value of the usage counter.
545 """
546 async with self._usage_lock:
547 self._usage_counter -= 1
548 return self._usage_counter
549
550 async def __aexit__(self, exc_type, exc_value, traceback):
551 """
552 Async context manager exit. Decrements a usage counter. If this is the
553 last exit (counter becomes zero), the client closes its connection pool.
554 """
555 current_usage = await asyncio.shield(self._decrement_usage())
556 if current_usage == 0:
557 # This was the last active context, so disconnect the pool.
558 await asyncio.shield(self.aclose())
559
    def __await__(self) -> Generator[Any, None, "RedisCluster"]:
        # Makes `await RedisCluster(...)` a shorthand for `await .initialize()`.
        return self.initialize().__await__()
562
    _DEL_MESSAGE = "Unclosed RedisCluster client"

    def __del__(
        self,
        # Bound as defaults at class-definition time so they stay reachable
        # during interpreter shutdown, when module globals may already be gone.
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Only warn if the client was initialized and never closed; hasattr
        # guards against __init__ having failed before _initialize was set.
        if hasattr(self, "_initialize") and not self._initialize:
            _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop — nowhere to report the leak.
                pass
577
578 async def on_connect(self, connection: Connection) -> None:
579 await connection.on_connect()
580
581 # Sending READONLY command to server to configure connection as
582 # readonly. Since each cluster node may change its server type due
583 # to a failover, we should establish a READONLY connection
584 # regardless of the server type. If this is a primary connection,
585 # READONLY would not affect executing write commands.
586 await connection.send_command("READONLY")
587 if str_if_bytes(await connection.read_response()) != "OK":
588 raise ConnectionError("READONLY command failed")
589
590 def get_nodes(self) -> List["ClusterNode"]:
591 """Get all nodes of the cluster."""
592 return list(self.nodes_manager.nodes_cache.values())
593
594 def get_primaries(self) -> List["ClusterNode"]:
595 """Get the primary nodes of the cluster."""
596 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
597
598 def get_replicas(self) -> List["ClusterNode"]:
599 """Get the replica nodes of the cluster."""
600 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
601
602 def get_random_node(self) -> "ClusterNode":
603 """Get a random node of the cluster."""
604 return random.choice(list(self.nodes_manager.nodes_cache.values()))
605
606 def get_default_node(self) -> "ClusterNode":
607 """Get the default node of the client."""
608 return self.nodes_manager.default_node
609
610 def set_default_node(self, node: "ClusterNode") -> None:
611 """
612 Set the default node of the client.
613
614 :raises DataError: if None is passed or node does not exist in cluster.
615 """
616 if not node or not self.get_node(node_name=node.name):
617 raise DataError("The requested node does not exist in the cluster.")
618
619 self.nodes_manager.default_node = node
620
    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """Get node by (host, port) or node_name; None if not found."""
        # Pure delegation; lookup rules live in NodesManager.get_node.
        return self.nodes_manager.get_node(host, port, node_name)
629
630 def get_node_from_key(
631 self, key: str, replica: bool = False
632 ) -> Optional["ClusterNode"]:
633 """
634 Get the cluster node corresponding to the provided key.
635
636 :param key:
637 :param replica:
638 | Indicates if a replica should be returned
639 |
640 None will returned if no replica holds this key
641
642 :raises SlotNotCoveredError: if the key is not covered by any slot.
643 """
644 slot = self.keyslot(key)
645 slot_cache = self.nodes_manager.slots_cache.get(slot)
646 if not slot_cache:
647 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
648
649 if replica:
650 if len(self.nodes_manager.slots_cache[slot]) < 2:
651 return None
652 node_idx = 1
653 else:
654 node_idx = 0
655
656 return slot_cache[node_idx]
657
    def get_random_primary_or_all_nodes(self, command_name):
        """
        Return a single randomly-chosen node for *command_name*.

        For read commands while READONLY mode is enabled
        (``read_from_replicas=True``), the node is drawn from all nodes
        (primaries and replicas); otherwise a random primary is returned.

        NOTE(review): only ``read_from_replicas`` is consulted here, not
        ``load_balancing_strategy`` — confirm that is intentional.
        """
        if self.read_from_replicas and command_name in READ_COMMANDS:
            return self.get_random_node()

        return self.get_random_primary_node()
666
667 def get_random_primary_node(self) -> "ClusterNode":
668 """
669 Returns a random primary node
670 """
671 return random.choice(self.get_primaries())
672
673 async def get_nodes_from_slot(self, command: str, *args):
674 """
675 Returns a list of nodes that hold the specified keys' slots.
676 """
677 # get the node that holds the key's slot
678 return [
679 self.nodes_manager.get_node_from_slot(
680 await self._determine_slot(command, *args),
681 self.read_from_replicas and command in READ_COMMANDS,
682 self.load_balancing_strategy if command in READ_COMMANDS else None,
683 )
684 ]
685
686 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
687 """
688 Returns a list of nodes for commands with a special policy.
689 """
690 if not self._aggregate_nodes:
691 raise RedisClusterException(
692 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
693 )
694
695 return self._aggregate_nodes
696
697 def keyslot(self, key: EncodableT) -> int:
698 """
699 Find the keyslot for a given key.
700
701 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
702 """
703 return key_slot(self.encoder.encode(key))
704
    def get_encoder(self) -> Encoder:
        """Get the encoder object of the client (built in __init__ from the
        encoding/decode_responses settings)."""
        return self.encoder
708
    def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
        """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.

        Note: the live dict is returned, not a copy; mutations affect
        connections created afterwards.
        """
        return self.connection_kwargs
712
    def set_retry(self, retry: Retry) -> None:
        """Replace the client's retry strategy object."""
        self.retry = retry
715
    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback for *command*, overriding any default."""
        self.response_callbacks[command] = callback
719
    async def _determine_nodes(
        self,
        command: str,
        *args: Any,
        request_policy: RequestPolicy,
        node_flag: Optional[str] = None,
    ) -> List["ClusterNode"]:
        """
        Resolve the list of target nodes the command should be executed on.

        A node flag — passed explicitly or predefined for the command — takes
        precedence over the supplied ``request_policy``.
        """
        if not node_flag:
            # get the nodes group for this command if it was predefined
            node_flag = self.command_flags.get(command)

        if node_flag in self._command_flags_mapping:
            # Backward compatibility: translate a legacy node flag into the
            # equivalent request policy.
            request_policy = self._command_flags_mapping[node_flag]

        policy_callback = self._policies_callback_mapping[request_policy]

        # Keyed policies are async and need the full args to compute the slot;
        # keyless ones only need the command name; the rest take no arguments.
        if request_policy == RequestPolicy.DEFAULT_KEYED:
            nodes = await policy_callback(command, *args)
        elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
            nodes = policy_callback(command)
        else:
            nodes = policy_callback()

        if command.lower() == "ft.aggregate":
            # Remember the targets so later FT.CURSOR calls can be routed to
            # the same nodes (see get_special_nodes).
            self._aggregate_nodes = nodes

        return nodes
749
    async def _determine_slot(self, command: str, *args: Any) -> int:
        """
        Return the hash slot the command should be routed to.

        :raises RedisClusterException: if no key can be extracted, or the
            command's keys do not all map to the same slot.
        """
        if self.command_flags.get(command) == SLOT_ID:
            # The command contains the slot ID
            return int(args[0])

        # Get the keys in the command

        # EVAL and EVALSHA are common enough that it's wasteful to go to the
        # redis server to parse the keys. Besides, there is a bug in redis<7.0
        # where `self._get_command_keys()` fails anyway. So, we special case
        # EVAL/EVALSHA.
        # - issue: https://github.com/redis/redis/issues/9493
        # - fix: https://github.com/redis/redis/pull/9733
        if command.upper() in ("EVAL", "EVALSHA"):
            # command syntax: EVAL "script body" num_keys ...
            if len(args) < 2:
                raise RedisClusterException(
                    f"Invalid args in command: {command, *args}"
                )
            # args[1] is num_keys; the keys follow immediately after it.
            keys = args[2 : 2 + int(args[1])]
            # if there are 0 keys, that means the script can be run on any node
            # so we can just return a random slot
            if not keys:
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
        else:
            # Ask the server-backed commands parser which args are keys.
            keys = await self.commands_parser.get_keys(command, *args)
            if not keys:
                # FCALL can call a function with 0 keys, that means the function
                # can be run on any node so we can just return a random slot
                if command.upper() in ("FCALL", "FCALL_RO"):
                    return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
                raise RedisClusterException(
                    "No way to dispatch this command to Redis Cluster. "
                    "Missing key.\nYou can execute the command by specifying "
                    f"target nodes.\nCommand: {args}"
                )

        # single key command
        if len(keys) == 1:
            return self.keyslot(keys[0])

        # multi-key command; we need to make sure all keys are mapped to
        # the same slot
        slots = {self.keyslot(key) for key in keys}
        if len(slots) != 1:
            raise RedisClusterException(
                f"{command} - all keys must map to the same key slot"
            )

        return slots.pop()
800
801 def _is_node_flag(self, target_nodes: Any) -> bool:
802 return isinstance(target_nodes, str) and target_nodes in self.node_flags
803
804 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
805 if isinstance(target_nodes, list):
806 nodes = target_nodes
807 elif isinstance(target_nodes, ClusterNode):
808 # Supports passing a single ClusterNode as a variable
809 nodes = [target_nodes]
810 elif isinstance(target_nodes, dict):
811 # Supports dictionaries of the format {node_name: node}.
812 # It enables to execute commands with multi nodes as follows:
813 # rc.cluster_save_config(rc.get_primaries())
814 nodes = list(target_nodes.values())
815 else:
816 raise TypeError(
817 "target_nodes type can be one of the following: "
818 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
819 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
820 f"The passed type is {type(target_nodes)}"
821 )
822 return nodes
823
824 async def _record_error_metric(
825 self,
826 error: Exception,
827 connection: Union[Connection, "ClusterNode"],
828 is_internal: bool = True,
829 retry_attempts: Optional[int] = None,
830 ):
831 """
832 Records error count metric directly.
833 Accepts either a Connection or ClusterNode object.
834 """
835 await record_error_count(
836 server_address=connection.host,
837 server_port=connection.port,
838 network_peer_address=connection.host,
839 network_peer_port=connection.port,
840 error_type=error,
841 retry_attempts=retry_attempts if retry_attempts is not None else 0,
842 is_internal=is_internal,
843 )
844
845 async def _record_command_metric(
846 self,
847 command_name: str,
848 duration_seconds: float,
849 connection: Union[Connection, "ClusterNode"],
850 error: Optional[Exception] = None,
851 ):
852 """
853 Records operation duration metric directly.
854 Accepts either a Connection or ClusterNode object.
855 """
856 # Connection has db attribute, ClusterNode has connection_kwargs
857 if hasattr(connection, "db"):
858 db = connection.db
859 else:
860 db = connection.connection_kwargs.get("db", 0)
861 await record_operation_duration(
862 command_name=command_name,
863 duration_seconds=duration_seconds,
864 server_address=connection.host,
865 server_port=connection.port,
866 db_namespace=str(db) if db is not None else None,
867 error=error,
868 )
869
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            # The caller pinned explicit node objects: honor them verbatim and
            # disable retries, since re-resolving target nodes on failure
            # would silently override the caller's choice.
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        # Ask the policy resolver how this command should be routed and how
        # multi-node replies should be aggregated.
        command_policies = await self._policy_resolver.resolve(args[0].lower())

        if not command_policies and not target_nodes_specified:
            # No resolver policy: fall back to the client's static command
            # flags, then to slot-based routing.
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    # Keyless command (or no topology yet): default routing.
                    command_policies = CommandPolicies()
                else:
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        failure_count = 0

        # Start timing for observability
        start_time = time.monotonic()

        for _ in range(execute_attempts):
            if self._initialize:
                # A previous failure flagged the topology as stale; rebuild
                # the nodes/slots caches before (re)trying.
                await self.initialize()
                if (
                    len(target_nodes) == 1
                    and target_nodes[0] == self.get_default_node()
                ):
                    # Replace the default cluster node
                    self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently and aggregate
                    # the replies keyed by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                # Exact-type match (not isinstance): only the whitelisted
                # error types trigger a cluster-level retry.
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache should be reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    failure_count += 1

                    if hasattr(e, "connection"):
                        await self._record_command_metric(
                            command_name=command,
                            duration_seconds=time.monotonic() - start_time,
                            connection=e.connection,
                            error=e,
                        )
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                        )
                    continue
                else:
                    # raise the exception
                    if hasattr(e, "connection"):
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                            is_internal=False,
                        )
                    raise e
1012
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Send a single command to ``target_node``, transparently following
        cluster redirects (MOVED/ASK) for up to ``RedisClusterRequestTTL``
        hops.

        Every outcome is reported to the observability recorder, and raised
        exceptions are annotated with ``e.connection`` (the ClusterNode) so
        callers can attribute failures to a node.

        :raises ClusterError: when the redirect budget (TTL) is exhausted.
        """
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL
        command = args[0]
        start_time = time.monotonic()

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # Previous attempt got an ASK redirect: switch to the
                    # redirected node and issue ASKING before retrying.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                response = await target_node.execute_command(*args, **kwargs)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                )
                return response
            except BusyLoadingError as e:
                # Node is loading its dataset; not retryable at this level.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MaxConnectionsError as e:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ConnectionError, TimeoutError) as e:
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ClusterDownError, SlotNotCoveredError) as e:
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    self.nodes_manager.move_slot(e)
                    moved = True
                # Not re-raised: the loop retries with the patched topology.
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except AskError as e:
                # One-shot redirect during slot migration; handled at the top
                # of the next loop iteration.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except TryAgainError as e:
                # Slot is mid-migration; back off slightly once half the TTL
                # budget is spent, then retry.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except ResponseError as e:
                # Server-side command error (e.g. WRONGTYPE): not retryable.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except Exception as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise

        # Loop fell through: too many consecutive redirects/retries.
        e = ClusterError("TTL exhausted.")
        e.connection = target_node
        await self._record_command_metric(
            command_name=command,
            duration_seconds=time.monotonic() - start_time,
            connection=target_node,
            error=e,
        )
        raise e
1198
1199 def pipeline(
1200 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1201 ) -> "ClusterPipeline":
1202 """
1203 Create & return a new :class:`~.ClusterPipeline` object.
1204
1205 Cluster implementation of pipeline does not support transaction or shard_hint.
1206
1207 :raises RedisClusterException: if transaction or shard_hint are truthy values
1208 """
1209 if shard_hint:
1210 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1211
1212 return ClusterPipeline(self, transaction)
1213
1214 def lock(
1215 self,
1216 name: KeyT,
1217 timeout: Optional[float] = None,
1218 sleep: float = 0.1,
1219 blocking: bool = True,
1220 blocking_timeout: Optional[float] = None,
1221 lock_class: Optional[Type[Lock]] = None,
1222 thread_local: bool = True,
1223 raise_on_release_error: bool = True,
1224 ) -> Lock:
1225 """
1226 Return a new Lock object using key ``name`` that mimics
1227 the behavior of threading.Lock.
1228
1229 If specified, ``timeout`` indicates a maximum life for the lock.
1230 By default, it will remain locked until release() is called.
1231
1232 ``sleep`` indicates the amount of time to sleep per loop iteration
1233 when the lock is in blocking mode and another client is currently
1234 holding the lock.
1235
1236 ``blocking`` indicates whether calling ``acquire`` should block until
1237 the lock has been acquired or to fail immediately, causing ``acquire``
1238 to return False and the lock not being acquired. Defaults to True.
1239 Note this value can be overridden by passing a ``blocking``
1240 argument to ``acquire``.
1241
1242 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1243 spend trying to acquire the lock. A value of ``None`` indicates
1244 continue trying forever. ``blocking_timeout`` can be specified as a
1245 float or integer, both representing the number of seconds to wait.
1246
1247 ``lock_class`` forces the specified lock implementation. Note that as
1248 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1249 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1250 you have created your own custom lock class.
1251
1252 ``thread_local`` indicates whether the lock token is placed in
1253 thread-local storage. By default, the token is placed in thread local
1254 storage so that a thread only sees its token, not a token set by
1255 another thread. Consider the following timeline:
1256
1257 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1258 thread-1 sets the token to "abc"
1259 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1260 Lock instance.
1261 time: 5, thread-1 has not yet completed. redis expires the lock
1262 key.
1263 time: 5, thread-2 acquired `my-lock` now that it's available.
1264 thread-2 sets the token to "xyz"
1265 time: 6, thread-1 finishes its work and calls release(). if the
1266 token is *not* stored in thread local storage, then
1267 thread-1 would see the token value as "xyz" and would be
1268 able to successfully release the thread-2's lock.
1269
1270 ``raise_on_release_error`` indicates whether to raise an exception when
1271 the lock is no longer owned when exiting the context manager. By default,
1272 this is True, meaning an exception will be raised. If False, the warning
1273 will be logged and the exception will be suppressed.
1274
1275 In some use cases it's necessary to disable thread local storage. For
1276 example, if you have code where one thread acquires a lock and passes
1277 that lock instance to a worker thread to release later. If thread
1278 local storage isn't disabled in this case, the worker thread won't see
1279 the token set by the thread that acquired the lock. Our assumption
1280 is that these cases aren't common and as such default to using
1281 thread local storage."""
1282 if lock_class is None:
1283 lock_class = Lock
1284 return lock_class(
1285 self,
1286 name,
1287 timeout=timeout,
1288 sleep=sleep,
1289 blocking=blocking,
1290 blocking_timeout=blocking_timeout,
1291 thread_local=thread_local,
1292 raise_on_release_error=raise_on_release_error,
1293 )
1294
1295 async def transaction(
1296 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1297 ):
1298 """
1299 Convenience method for executing the callable `func` as a transaction
1300 while watching all keys specified in `watches`. The 'func' callable
1301 should expect a single argument which is a Pipeline object.
1302 """
1303 shard_hint = kwargs.pop("shard_hint", None)
1304 value_from_callable = kwargs.pop("value_from_callable", False)
1305 watch_delay = kwargs.pop("watch_delay", None)
1306 async with self.pipeline(True, shard_hint) as pipe:
1307 while True:
1308 try:
1309 if watches:
1310 await pipe.watch(*watches)
1311 func_value = await func(pipe)
1312 exec_value = await pipe.execute()
1313 return func_value if value_from_callable else exec_value
1314 except WatchError:
1315 if watch_delay is not None and watch_delay > 0:
1316 time.sleep(watch_delay)
1317 continue
1318
1319
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).
    """

    # __slots__ keeps per-node instances small; a cluster client may hold
    # many of these.
    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        """
        :param host: node address; "localhost" is resolved to an IP so node
            names match the addresses reported by the cluster.
        :param port: node port.
        :param server_type: PRIMARY/REPLICA classification (or None).
        :param max_connections: hard cap on pooled connections for this node.
        :param connection_class: Connection (or SSLConnection) factory.
        :param connection_kwargs: forwarded to every Connection created.
        """
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Popped so the callbacks dict is not passed on to Connection().
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # _connections tracks every connection ever created for this node;
        # _free is the subset currently available for checkout.
        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Two nodes are equal iff they share the same "host:port" name.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        # Bound as defaults at definition time so they are still reachable
        # during interpreter shutdown.
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn (once) if any connection is still open when the node is GC'd.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop; nothing to report to.
                    pass
                break

    async def disconnect(self) -> None:
        """Disconnect every connection; re-raise the first failure, if any."""
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Pop a free connection, or create a new one if under the cap.

        :raises MaxConnectionsError: when the pool is exhausted.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, it will be disconnected
        lazily when next acquired via disconnect_if_needed().
        """
        self._free.append(connection)

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected when released back to the pool.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply for ``command`` from ``connection`` and post-process it
        via the registered response callback, if any.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Some commands treat an error reply as a benign empty result.
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Execute a single command on this node using a pooled connection.

        NOTE(review): if send_packed_command raises, the connection is never
        returned to _free - confirm whether that leak is intentional.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self.disconnect_if_needed(connection)
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in one packed write, then read each reply into
        ``cmd.result``. Returns True if any command's result is an exception.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                # Store the exception per-command; the caller decides whether
                # to raise.
                cmd.result = e
                ret = True

        # Release connection
        await self.disconnect_if_needed(connection)
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every free connection with a freshly issued token,
        then put the connections back into the free queue.
        """
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1582
1583
1584class NodesManager:
1585 __slots__ = (
1586 "_dynamic_startup_nodes",
1587 "_event_dispatcher",
1588 "_background_tasks",
1589 "connection_kwargs",
1590 "default_node",
1591 "nodes_cache",
1592 "_epoch",
1593 "read_load_balancer",
1594 "_initialize_lock",
1595 "require_full_coverage",
1596 "slots_cache",
1597 "startup_nodes",
1598 "address_remap",
1599 )
1600
1601 def __init__(
1602 self,
1603 startup_nodes: List["ClusterNode"],
1604 require_full_coverage: bool,
1605 connection_kwargs: Dict[str, Any],
1606 dynamic_startup_nodes: bool = True,
1607 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1608 event_dispatcher: Optional[EventDispatcher] = None,
1609 ) -> None:
1610 self.startup_nodes = {node.name: node for node in startup_nodes}
1611 self.require_full_coverage = require_full_coverage
1612 self.connection_kwargs = connection_kwargs
1613 self.address_remap = address_remap
1614
1615 self.default_node: "ClusterNode" = None
1616 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1617 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1618 self._epoch: int = 0
1619 self.read_load_balancer = LoadBalancer()
1620 self._initialize_lock: asyncio.Lock = asyncio.Lock()
1621
1622 self._background_tasks: Set[asyncio.Task] = set()
1623 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1624 if event_dispatcher is None:
1625 self._event_dispatcher = EventDispatcher()
1626 else:
1627 self._event_dispatcher = event_dispatcher
1628
1629 def get_node(
1630 self,
1631 host: Optional[str] = None,
1632 port: Optional[int] = None,
1633 node_name: Optional[str] = None,
1634 ) -> Optional["ClusterNode"]:
1635 if host and port:
1636 # the user passed host and port
1637 if host == "localhost":
1638 host = socket.gethostbyname(host)
1639 return self.nodes_cache.get(get_node_name(host=host, port=port))
1640 elif node_name:
1641 return self.nodes_cache.get(node_name)
1642 else:
1643 raise DataError(
1644 "get_node requires one of the following: 1. node name 2. host and port"
1645 )
1646
1647 def set_nodes(
1648 self,
1649 old: Dict[str, "ClusterNode"],
1650 new: Dict[str, "ClusterNode"],
1651 remove_old: bool = False,
1652 ) -> None:
1653 if remove_old:
1654 for name in list(old.keys()):
1655 if name not in new:
1656 # Node is removed from cache before disconnect starts,
1657 # so it won't be found in lookups during disconnect
1658 # Mark active connections for reconnect so they get disconnected after current command completes
1659 # and disconnect free connections immediately
1660 # the node is removed from the cache before the connections changes so it won't be used and should be safe
1661 # not to wait for the disconnects
1662 removed_node = old.pop(name)
1663 removed_node.update_active_connections_for_reconnect()
1664 task = asyncio.create_task(
1665 removed_node.disconnect_free_connections()
1666 )
1667 self._background_tasks.add(task)
1668 task.add_done_callback(self._background_tasks.discard)
1669
1670 for name, node in new.items():
1671 if name in old:
1672 # Preserve the existing node but mark connections for reconnect.
1673 # This method is sync so we can't call disconnect_free_connections()
1674 # which is async. Instead, we mark free connections for reconnect
1675 # and they will be lazily disconnected when acquired via
1676 # disconnect_if_needed() to avoid race conditions.
1677 # TODO: Make this method async in the next major release to allow
1678 # immediate disconnection of free connections.
1679 existing_node = old[name]
1680 existing_node.update_active_connections_for_reconnect()
1681 for conn in existing_node._free:
1682 conn.mark_for_reconnect()
1683 continue
1684 # New node is detected and should be added to the pool
1685 old[name] = node
1686
1687 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1688 """
1689 Move a failing node to the end of startup_nodes and nodes_cache so it's
1690 tried last during reinitialization and when selecting the default node.
1691 If the node is not in the respective list, nothing is done.
1692 """
1693 # Move in startup_nodes
1694 if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1695 node = self.startup_nodes.pop(node_name)
1696 self.startup_nodes[node_name] = node # Re-insert at end
1697
1698 # Move in nodes_cache - this affects get_nodes_by_server_type ordering
1699 # which is used to select the default_node during initialize()
1700 if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1701 node = self.nodes_cache.pop(node_name)
1702 self.nodes_cache[node_name] = node # Re-insert at end
1703
    def move_slot(self, e: AskError | MovedError):
        """
        Patch the nodes/slots caches from a MOVED (or ASK) redirect so the
        next attempt targets the new slot owner without a full reinitialize.

        :param e: the redirect error, carrying host, port and slot_id.
        """
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        slot_nodes = self.slots_cache[e.slot_id]
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op
1740
1741 def get_node_from_slot(
1742 self,
1743 slot: int,
1744 read_from_replicas: bool = False,
1745 load_balancing_strategy=None,
1746 ) -> "ClusterNode":
1747 if read_from_replicas is True and load_balancing_strategy is None:
1748 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1749
1750 try:
1751 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1752 # get the server index using the strategy defined in load_balancing_strategy
1753 primary_name = self.slots_cache[slot][0].name
1754 node_idx = self.read_load_balancer.get_server_index(
1755 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1756 )
1757 return self.slots_cache[slot][node_idx]
1758 return self.slots_cache[slot][0]
1759 except (IndexError, TypeError):
1760 raise SlotNotCoveredError(
1761 f'Slot "{slot}" not covered by the cluster. '
1762 f'"require_full_coverage={self.require_full_coverage}"'
1763 )
1764
1765 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1766 return [
1767 node
1768 for node in self.nodes_cache.values()
1769 if node.server_type == server_type
1770 ]
1771
1772 async def initialize(self) -> None:
1773 self.read_load_balancer.reset()
1774 tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
1775 tmp_slots: Dict[int, List["ClusterNode"]] = {}
1776 disagreements = []
1777 startup_nodes_reachable = False
1778 fully_covered = False
1779 exception = None
1780 epoch = self._epoch
1781
1782 async with self._initialize_lock:
1783 if self._epoch != epoch:
1784 # another initialize call has already reinitialized the
1785 # nodes since we started waiting for the lock;
1786 # we don't need to do it again.
1787 return
1788
1789 # Convert to tuple to prevent RuntimeError if self.startup_nodes
1790 # is modified during iteration
1791 for startup_node in tuple(self.startup_nodes.values()):
1792 try:
1793 # Make sure cluster mode is enabled on this node
1794 try:
1795 self._event_dispatcher.dispatch(
1796 AfterAsyncClusterInstantiationEvent(
1797 self.nodes_cache,
1798 self.connection_kwargs.get("credential_provider", None),
1799 )
1800 )
1801 cluster_slots = await startup_node.execute_command(
1802 "CLUSTER SLOTS"
1803 )
1804 except ResponseError:
1805 raise RedisClusterException(
1806 "Cluster mode is not enabled on this node"
1807 )
1808 startup_nodes_reachable = True
1809 except Exception as e:
1810 # Try the next startup node.
1811 # The exception is saved and raised only if we have no more nodes.
1812 exception = e
1813 continue
1814
1815 # CLUSTER SLOTS command results in the following output:
1816 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
1817 # where each node contains the following list: [IP, port, node_id]
1818 # Therefore, cluster_slots[0][2][0] will be the IP address of the
1819 # primary node of the first slot section.
1820 # If there's only one server in the cluster, its ``host`` is ''
1821 # Fix it to the host in startup_nodes
1822 if (
1823 len(cluster_slots) == 1
1824 and not cluster_slots[0][2][0]
1825 and len(self.startup_nodes) == 1
1826 ):
1827 cluster_slots[0][2][0] = startup_node.host
1828
1829 for slot in cluster_slots:
1830 for i in range(2, len(slot)):
1831 slot[i] = [str_if_bytes(val) for val in slot[i]]
1832 primary_node = slot[2]
1833 host = primary_node[0]
1834 if host == "":
1835 host = startup_node.host
1836 port = int(primary_node[1])
1837 host, port = self.remap_host_port(host, port)
1838
1839 nodes_for_slot = []
1840
1841 target_node = tmp_nodes_cache.get(get_node_name(host, port))
1842 if not target_node:
1843 target_node = ClusterNode(
1844 host, port, PRIMARY, **self.connection_kwargs
1845 )
1846 # add this node to the nodes cache
1847 tmp_nodes_cache[target_node.name] = target_node
1848 nodes_for_slot.append(target_node)
1849
1850 replica_nodes = slot[3:]
1851 for replica_node in replica_nodes:
1852 host = replica_node[0]
1853 port = replica_node[1]
1854 host, port = self.remap_host_port(host, port)
1855
1856 target_replica_node = tmp_nodes_cache.get(
1857 get_node_name(host, port)
1858 )
1859 if not target_replica_node:
1860 target_replica_node = ClusterNode(
1861 host, port, REPLICA, **self.connection_kwargs
1862 )
1863 # add this node to the nodes cache
1864 tmp_nodes_cache[target_replica_node.name] = target_replica_node
1865 nodes_for_slot.append(target_replica_node)
1866
1867 for i in range(int(slot[0]), int(slot[1]) + 1):
1868 if i not in tmp_slots:
1869 tmp_slots[i] = nodes_for_slot
1870 else:
1871 # Validate that 2 nodes want to use the same slot cache
1872 # setup
1873 tmp_slot = tmp_slots[i][0]
1874 if tmp_slot.name != target_node.name:
1875 disagreements.append(
1876 f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
1877 )
1878
1879 if len(disagreements) > 5:
1880 raise RedisClusterException(
1881 f"startup_nodes could not agree on a valid "
1882 f"slots cache: {', '.join(disagreements)}"
1883 )
1884
1885 # Validate if all slots are covered or if we should try next startup node
1886 fully_covered = True
1887 for i in range(REDIS_CLUSTER_HASH_SLOTS):
1888 if i not in tmp_slots:
1889 fully_covered = False
1890 break
1891 if fully_covered:
1892 break
1893
1894 if not startup_nodes_reachable:
1895 raise RedisClusterException(
1896 f"Redis Cluster cannot be connected. Please provide at least "
1897 f"one reachable node: {str(exception)}"
1898 ) from exception
1899
1900 # Check if the slots are not fully covered
1901 if not fully_covered and self.require_full_coverage:
1902 # Despite the requirement that the slots be covered, there
1903 # isn't a full coverage
1904 raise RedisClusterException(
1905 f"All slots are not covered after query all startup_nodes. "
1906 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
1907 f"covered..."
1908 )
1909
1910 # Set the tmp variables to the real variables
1911 self.slots_cache = tmp_slots
1912 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1913
1914 if self._dynamic_startup_nodes:
1915 # Populate the startup nodes with all discovered nodes
1916 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1917
1918 # Set the default node
1919 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
1920 self._epoch += 1
1921
1922 async def aclose(self, attr: str = "nodes_cache") -> None:
1923 self.default_node = None
1924 await asyncio.gather(
1925 *(
1926 asyncio.create_task(node.disconnect())
1927 for node in getattr(self, attr).values()
1928 )
1929 )
1930
1931 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1932 """
1933 Remap the host and port returned from the cluster to a different
1934 internal value. Useful if the client is not connecting directly
1935 to the cluster.
1936 """
1937 if self.address_remap:
1938 return self.address_remap((host, port))
1939 return host, port
1940
1941
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    """

    __slots__ = ("cluster_client", "_transaction", "_execution_strategy")

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # A truthy ``transaction`` selects MULTI/EXEC semantics; otherwise
        # commands are fanned out per owning node without atomicity.
        self._execution_strategy: ExecutionStrategy = (
            PipelineStrategy(self)
            if not self._transaction
            else TransactionStrategy(self)
        )

    @property
    def nodes_manager(self) -> "NodesManager":
        """Get the nodes manager from the cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback on the cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the underlying strategy (and the cluster client if
        it has not been initialized yet) and return this pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        "Pipeline instances should always evaluate to True on Python 3+"
        return True

    def __len__(self) -> int:
        """Number of commands currently buffered by the strategy."""
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            # The pipeline is always left empty, even on failure.
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        """Queue ``command`` once per slot-group of ``keys`` (used by
        multi-key commands such as DEL/EXISTS/TOUCH/UNLINK)."""
        for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
            self.execute_command(command, *slot_keys)

        return self

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """Discard all buffered commands of a transactional pipeline."""
        return await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``"""
        # Propagate the strategy's result (previously dropped) so callers
        # can observe the WATCH response.
        return await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        return await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Unlink (non-blocking delete) the key ``names``.

        Returns the strategy's result (previously dropped), matching the
        chainable behavior of the other queueing methods.
        """
        return await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue per-slot MSET commands covering ``mapping`` (non-atomic)."""
        return self._execution_strategy.mset_nonatomic(mapping)
2102
2103
for command in PIPELINE_BLOCKED_COMMANDS:
    # Normalize e.g. "CLUSTER SETSLOT" to the method name "cluster_setslot".
    command = command.replace(" ", "_").lower()
    # mset_nonatomic has a real pipeline implementation, so leave it callable.
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
2110
2111
class PipelineCommand:
    """A single buffered command within a cluster pipeline.

    Holds the raw command arguments, the command's index in the pipeline,
    and — after execution — either its response or the exception it raised.
    """

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index of this command within the pipeline's command queue.
        self.position = position
        # Raw command arguments and connection kwargs, forwarded verbatim.
        self.args = args
        self.kwargs = kwargs
        # Filled in after execution: the response or the raised exception.
        self.result: Union[Any, Exception] = None
        # Resolved request/response policies, assigned during execution.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return f"[{self.position}] {self.args} ({self.kwargs})"
2122
2123
class ExecutionStrategy(ABC):
    """Interface implemented by the pipeline execution strategies.

    ``ClusterPipeline`` delegates every public operation to one of these:
    a non-atomic fan-out strategy or a MULTI/EXEC transaction strategy.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """Prepare the strategy for use. See ClusterPipeline.initialize()."""
        ...

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Buffer a raw command. See ClusterPipeline.execute_command()."""
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Run all buffered commands, retrying as configured on the client,
        and return their results.

        See ClusterPipeline.execute()
        """
        ...

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue MSET commands split by hash slot.

        See ClusterPipeline.mset_nonatomic()
        """
        ...

    @abstractmethod
    async def reset(self):
        """Return the strategy to an empty state. See ClusterPipeline.reset()."""
        ...

    @abstractmethod
    def multi(self):
        """Begin an explicit transactional block. See ClusterPipeline.multi()."""
        ...

    @abstractmethod
    async def watch(self, *names):
        """Watch the given keys. See ClusterPipeline.watch()."""
        ...

    @abstractmethod
    async def unwatch(self):
        """Stop watching all previously watched keys.

        See ClusterPipeline.unwatch()
        """
        ...

    @abstractmethod
    async def discard(self):
        """Discard the buffered commands. See ClusterPipeline.discard()."""
        ...

    @abstractmethod
    async def unlink(self, *names):
        """Unlink the key given by ``names``.

        See ClusterPipeline.unlink()
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of buffered commands."""
        ...
2222
2223
class AbstractStrategy(ExecutionStrategy):
    """Shared plumbing for the concrete strategies: command-queue handling,
    lazy cluster initialization, and exception annotation."""

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the cluster client if it still needs it and clear
        the command queue; return the owning pipeline."""
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Append a raw command to the queue and return the pipeline so
        calls can be chained."""
        position = len(self._command_queue)
        self._command_queue.append(PipelineCommand(position, *args, **kwargs))
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled
        """
        # Prefix the first exception arg with the failing command's index
        # and (truncated) text so pipeline errors are easy to locate.
        cmd_text = " ".join(safe_str(part) for part in command)
        msg = (
            f"Command # {number} ({truncate_text(cmd_text)}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        ...

    @abstractmethod
    async def reset(self):
        ...

    @abstractmethod
    def multi(self):
        ...

    @abstractmethod
    async def watch(self, *names):
        ...

    @abstractmethod
    async def unwatch(self):
        ...

    @abstractmethod
    async def discard(self):
        ...

    @abstractmethod
    async def unlink(self, *names):
        ...

    def __len__(self) -> int:
        return len(self._command_queue)
2292
2293
class PipelineStrategy(AbstractStrategy):
    """Non-transactional pipeline execution.

    Buffered commands are grouped per target node and each group is sent as
    one batch; the batches run concurrently across nodes. On redirection
    errors, individual commands may be retried via regular execution.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Split ``mapping`` by hash slot and queue one MSET per slot."""
        encoder = self._pipe.cluster_client.encoder

        # slot -> flat [key, value, key, value, ...] list for that slot.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Run the buffered commands, retrying whole batches on
        cluster-level errors.

        On a retryable error the client is closed (forcing a topology
        refresh on the next attempt) and the batch is retried up to the
        client's configured retry count. The queue is always cleared in
        the ``finally`` block.
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # All other errors should be raised.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """Resolve a target node for every pending command, send one batch
        per node concurrently, then handle errors, metrics and default-node
        health.

        :param client: cluster client used for routing and execution
        :param stack: all pipeline commands; results are stored in place on
            each ``PipelineCommand``
        :returns: the results of ``stack`` in original order
        :raises RedisClusterException: when no (or more than one) target node
            can be determined for a command
        """
        # Commands without a result yet, or whose previous attempt failed.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # node name -> (node, commands routed to that node)
        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Explicit node targets take precedence over policy routing.
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    # No resolver policy: fall back to legacy command flags,
                    # then to slot-based defaults.
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            # A pipelined command must resolve to exactly one node.
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Start timing for observability
        start_time = time.monotonic()

        # Fan out one batch per node; results/errors land on each command.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        # Record operation duration for each node.
        # The duration is measured from before the fan-out until all batches
        # finished, so every node reports the overall batch duration.
        for node_name, (node, commands) in nodes.items():
            # Find the first error in this node's commands, if any
            node_error = None
            for cmd in commands:
                if isinstance(cmd.result, Exception):
                    node_error = cmd.result
                    break

            db = node.connection_kwargs.get("db", 0)
            await record_operation_duration(
                command_name="PIPELINE",
                duration_seconds=time.monotonic() - start_time,
                server_address=node.host,
                server_port=node.port,
                db_namespace=str(db) if db is not None else None,
                error=node_error,
            )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

        default_cluster_node = client.get_default_node()

        # Check whether the default node was used. In some cases,
        # 'client.get_default_node()' may return None. The check below
        # prevents a potential AttributeError.
        if default_cluster_node is not None:
            default_node = nodes.get(default_cluster_node.name)
            if default_node is not None:
                # This pipeline execution used the default node, check if we need
                # to replace it.
                # Note: when the error is raised we'll reset the default node in the
                # caller function.
                for cmd in default_node[1]:
                    # Check if it has a command that failed with a relevant
                    # exception
                    if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                        client.replace_default_node()
                        break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        # Transactions require TransactionStrategy.
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        # WATCH is only meaningful with MULTI/EXEC semantics.
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue an UNLINK for a single key (multi-key unlink unsupported)."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2519
2520
class TransactionStrategy(AbstractStrategy):
    """MULTI/EXEC execution of a cluster pipeline on a single pinned node."""

    # Commands that carry no key and therefore have no hash slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands executed immediately instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which the server no longer tracks WATCHed keys.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    # Errors indicating the slot has moved to another node.
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    # Errors treated as connection-level failures (drop + reconnect).
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )
2532
2533 def __init__(self, pipe: ClusterPipeline) -> None:
2534 super().__init__(pipe)
2535 self._explicit_transaction = False
2536 self._watching = False
2537 self._pipeline_slots: Set[int] = set()
2538 self._transaction_node: Optional[ClusterNode] = None
2539 self._transaction_connection: Optional[Connection] = None
2540 self._executing = False
2541 self._retry = copy(self._pipe.cluster_client.retry)
2542 self._retry.update_supported_errors(
2543 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2544 )
2545
2546 def _get_client_and_connection_for_transaction(
2547 self,
2548 ) -> Tuple[ClusterNode, Connection]:
2549 """
2550 Find a connection for a pipeline transaction.
2551
2552 For running an atomic transaction, watch keys ensure that contents have not been
2553 altered as long as the watch commands for those keys were sent over the same
2554 connection. So once we start watching a key, we fetch a connection to the
2555 node that owns that slot and reuse it.
2556 """
2557 if not self._pipeline_slots:
2558 raise RedisClusterException(
2559 "At least a command with a key is needed to identify a node"
2560 )
2561
2562 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2563 list(self._pipeline_slots)[0], False
2564 )
2565 self._transaction_node = node
2566
2567 if not self._transaction_connection:
2568 connection: Connection = self._transaction_node.acquire_connection()
2569 self._transaction_connection = connection
2570
2571 return self._transaction_node, self._transaction_connection
2572
    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """Queue or immediately run a command within the transaction.

        The async routing work (``_execute_command``) is driven on a private
        event loop inside a worker thread via ``asyncio.run`` because this
        method must remain synchronous to match the shared pipeline API.
        NOTE(review): on the immediate-execution path ``_execute_command``
        returns an un-awaited coroutine, so ``response`` may itself be a
        coroutine that the caller (e.g. ``watch()``) awaits on its own event
        loop — confirm against ``_immediate_execute_command`` usage.
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        # Run on a dedicated thread so a fresh event loop can be created
        # even when the caller's loop is already running.
        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response
2594
    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """Route one command: execute immediately while WATCHing (or for
        WATCH/UNWATCH themselves), otherwise buffer it in the queue.

        Tracks the hash slot of every keyed command so the transaction can be
        validated to touch exactly one slot.

        NOTE(review): on the immediate path the coroutine returned by
        ``_immediate_execute_command`` is deliberately NOT awaited here; it is
        handed back through ``execute_command`` and awaited by the caller on
        the original event loop.
        """
        # Make sure slot/topology metadata is available before routing.
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                # All immediate commands must stay on the single pinned slot.
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            # Buffer the command for MULTI/EXEC execution.
            return super().execute_command(*args, **kwargs)
2630
2631 def _validate_watch(self):
2632 if self._explicit_transaction:
2633 raise RedisError("Cannot issue a WATCH after a MULTI")
2634
2635 self._watching = True
2636
2637 async def _immediate_execute_command(self, *args, **options):
2638 return await self._retry.call_with_retry(
2639 lambda: self._get_connection_and_send_command(*args, **options),
2640 self._reinitialize_on_error,
2641 with_failure_count=True,
2642 )
2643
2644 async def _get_connection_and_send_command(self, *args, **options):
2645 redis_node, connection = self._get_client_and_connection_for_transaction()
2646 # Only disconnect if not watching - disconnecting would lose WATCH state
2647 if not self._watching:
2648 await redis_node.disconnect_if_needed(connection)
2649
2650 # Start timing for observability
2651 start_time = time.monotonic()
2652
2653 try:
2654 response = await self._send_command_parse_response(
2655 connection, redis_node, args[0], *args, **options
2656 )
2657
2658 await record_operation_duration(
2659 command_name=args[0],
2660 duration_seconds=time.monotonic() - start_time,
2661 server_address=connection.host,
2662 server_port=connection.port,
2663 db_namespace=str(connection.db),
2664 )
2665
2666 return response
2667 except Exception as e:
2668 e.connection = connection
2669 await record_operation_duration(
2670 command_name=args[0],
2671 duration_seconds=time.monotonic() - start_time,
2672 server_address=connection.host,
2673 server_port=connection.port,
2674 db_namespace=str(connection.db),
2675 error=e,
2676 )
2677 raise
2678
2679 async def _send_command_parse_response(
2680 self,
2681 connection: Connection,
2682 redis_node: ClusterNode,
2683 command_name,
2684 *args,
2685 **options,
2686 ):
2687 """
2688 Send a command and parse the response
2689 """
2690
2691 await connection.send_command(*args)
2692 output = await redis_node.parse_response(connection, command_name, **options)
2693
2694 if command_name in self.UNWATCH_COMMANDS:
2695 self._watching = False
2696 return output
2697
2698 async def _reinitialize_on_error(self, error, failure_count):
2699 if hasattr(error, "connection"):
2700 await record_error_count(
2701 server_address=error.connection.host,
2702 server_port=error.connection.port,
2703 network_peer_address=error.connection.host,
2704 network_peer_port=error.connection.port,
2705 error_type=error,
2706 retry_attempts=failure_count,
2707 is_internal=True,
2708 )
2709
2710 if self._watching:
2711 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2712 raise WatchError("Slot rebalancing occurred while watching keys")
2713
2714 if (
2715 type(error) in self.SLOT_REDIRECT_ERRORS
2716 or type(error) in self.CONNECTION_ERRORS
2717 ):
2718 if self._transaction_connection and self._transaction_node:
2719 # Disconnect and release back to pool
2720 await self._transaction_connection.disconnect()
2721 self._transaction_node.release(self._transaction_connection)
2722 self._transaction_connection = None
2723
2724 self._pipe.cluster_client.reinitialize_counter += 1
2725 if (
2726 self._pipe.cluster_client.reinitialize_steps
2727 and self._pipe.cluster_client.reinitialize_counter
2728 % self._pipe.cluster_client.reinitialize_steps
2729 == 0
2730 ):
2731 await self._pipe.cluster_client.nodes_manager.initialize()
2732 self.reinitialize_counter = 0
2733 else:
2734 if isinstance(error, AskError):
2735 self._pipe.cluster_client.nodes_manager.move_slot(error)
2736
2737 self._executing = False
2738
2739 async def _raise_first_error(self, responses, stack, start_time):
2740 """
2741 Raise the first exception on the stack
2742 """
2743 for r, cmd in zip(responses, stack):
2744 if isinstance(r, Exception):
2745 self._annotate_exception(r, cmd.position + 1, cmd.args)
2746
2747 await record_operation_duration(
2748 command_name="TRANSACTION",
2749 duration_seconds=time.monotonic() - start_time,
2750 server_address=self._transaction_connection.host,
2751 server_port=self._transaction_connection.port,
2752 db_namespace=str(self._transaction_connection.db),
2753 error=r,
2754 )
2755
2756 raise r
2757
2758 def mset_nonatomic(
2759 self, mapping: Mapping[AnyKeyT, EncodableT]
2760 ) -> "ClusterPipeline":
2761 raise NotImplementedError("Method is not supported in transactional context.")
2762
2763 async def execute(
2764 self, raise_on_error: bool = True, allow_redirections: bool = True
2765 ) -> List[Any]:
2766 stack = self._command_queue
2767 if not stack and (not self._watching or not self._pipeline_slots):
2768 return []
2769
2770 return await self._execute_transaction_with_retries(stack, raise_on_error)
2771
2772 async def _execute_transaction_with_retries(
2773 self, stack: List["PipelineCommand"], raise_on_error: bool
2774 ):
2775 return await self._retry.call_with_retry(
2776 lambda: self._execute_transaction(stack, raise_on_error),
2777 lambda error, failure_count: self._reinitialize_on_error(
2778 error, failure_count
2779 ),
2780 with_failure_count=True,
2781 )
2782
    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Send the queued commands wrapped in MULTI/EXEC on the pinned
        connection and collect their results.

        Every reply (MULTI, each queued command, EXEC) is drained from the
        socket even when individual commands fail, so the connection stays
        usable afterwards.

        :raises CrossSlotTransactionError: when the queued commands span more
            than one hash slot
        :raises WatchError: when EXEC returns nil (a watched key changed)
        :raises InvalidPipelineStack: when the EXEC reply length does not
            match the command queue
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        # Wrap the user's commands in MULTI ... EXEC and send as one packet.
        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            cluster_error.connection = connection
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # EXEC aborts when a queued command was rejected; surface the
            # first queuing error if one was captured above.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): ``errors`` mixes (index, value) tuples (from
        # EMPTY_RESPONSE placeholders) with bare exceptions appended above;
        # a bare exception would fail to unpack here. In practice queuing
        # errors abort EXEC before this point — confirm upstream.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data
2901
2902 async def reset(self):
2903 self._command_queue = []
2904
2905 # make sure to reset the connection state in the event that we were
2906 # watching something
2907 if self._transaction_connection:
2908 try:
2909 if self._watching:
2910 # call this manually since our unwatch or
2911 # immediate_execute_command methods can call reset()
2912 await self._transaction_connection.send_command("UNWATCH")
2913 await self._transaction_connection.read_response()
2914 # we can safely return the connection to the pool here since we're
2915 # sure we're no longer WATCHing anything
2916 self._transaction_node.release(self._transaction_connection)
2917 self._transaction_connection = None
2918 except self.CONNECTION_ERRORS:
2919 # disconnect will also remove any previous WATCHes
2920 if self._transaction_connection and self._transaction_node:
2921 await self._transaction_connection.disconnect()
2922 self._transaction_node.release(self._transaction_connection)
2923 self._transaction_connection = None
2924
2925 # clean up the other instance attributes
2926 self._transaction_node = None
2927 self._watching = False
2928 self._explicit_transaction = False
2929 self._pipeline_slots = set()
2930 self._executing = False
2931
2932 def multi(self):
2933 if self._explicit_transaction:
2934 raise RedisError("Cannot issue nested calls to MULTI")
2935 if self._command_queue:
2936 raise RedisError(
2937 "Commands without an initial WATCH have already been issued"
2938 )
2939 self._explicit_transaction = True
2940
2941 async def watch(self, *names):
2942 if self._explicit_transaction:
2943 raise RedisError("Cannot issue a WATCH after a MULTI")
2944
2945 return await self.execute_command("WATCH", *names)
2946
2947 async def unwatch(self):
2948 if self._watching:
2949 return await self.execute_command("UNWATCH")
2950
2951 return True
2952
2953 async def discard(self):
2954 await self.reset()
2955
2956 async def unlink(self, *names):
2957 return self.execute_command("UNLINK", *names)