1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34)
35from redis.asyncio.client import ResponseCallbackT
36from redis.asyncio.connection import Connection, SSLConnection, parse_url
37from redis.asyncio.lock import Lock
38from redis.asyncio.retry import Retry
39from redis.auth.token import TokenInterface
40from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
41from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
42from redis.cluster import (
43 PIPELINE_BLOCKED_COMMANDS,
44 PRIMARY,
45 REPLICA,
46 SLOT_ID,
47 AbstractRedisCluster,
48 LoadBalancer,
49 LoadBalancingStrategy,
50 block_pipeline_command,
51 get_node_name,
52 parse_cluster_slots,
53)
54from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
55from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
56from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
57from redis.credentials import CredentialProvider
58from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
59from redis.exceptions import (
60 AskError,
61 BusyLoadingError,
62 ClusterDownError,
63 ClusterError,
64 ConnectionError,
65 CrossSlotTransactionError,
66 DataError,
67 ExecAbortError,
68 InvalidPipelineStack,
69 MaxConnectionsError,
70 MovedError,
71 RedisClusterException,
72 RedisError,
73 ResponseError,
74 SlotNotCoveredError,
75 TimeoutError,
76 TryAgainError,
77 WatchError,
78)
79from redis.typing import AnyKeyT, EncodableT, KeyT
80from redis.utils import (
81 SSL_AVAILABLE,
82 deprecated_args,
83 deprecated_function,
84 get_lib_version,
85 safe_str,
86 str_if_bytes,
87 truncate_text,
88)
89
if SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    # ssl module is unavailable in this interpreter build; stub the names so
    # the type annotations below (e.g. Optional[TLSVersion]) still evaluate.
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Accepted forms of the ``target_nodes`` argument used by many commands.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
100
101
102class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
103 """
104 Create a new RedisCluster client.
105
106 Pass one of parameters:
107
108 - `host` & `port`
109 - `startup_nodes`
110
111 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
112 | Use ``await`` :meth:`close` to disconnect connections & close client.
113
114 Many commands support the target_nodes kwarg. It can be one of the
115 :attr:`NODE_FLAGS`:
116
117 - :attr:`PRIMARIES`
118 - :attr:`REPLICAS`
119 - :attr:`ALL_NODES`
120 - :attr:`RANDOM`
121 - :attr:`DEFAULT_NODE`
122
123 Note: This client is not thread/process/fork safe.
124
125 :param host:
126 | Can be used to point to a startup node
127 :param port:
128 | Port used if **host** is provided
129 :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
131 :param require_full_coverage:
132 | When set to ``False``: the client will not require a full coverage of
133 the slots. However, if not all slots are covered, and at least one node
134 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
135 a :class:`~.ClusterDownError` for some key-based commands.
136 | When set to ``True``: all slots must be covered to construct the cluster
137 client. If not all slots are covered, :class:`~.RedisClusterException` will be
138 thrown.
139 | See:
140 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
141 :param read_from_replicas:
142 | @deprecated - please use load_balancing_strategy instead
143 | Enable read from replicas in READONLY mode.
144 When set to true, read commands will be assigned between the primary and
        its replicas in a round-robin manner.
146 The data read from replicas is eventually consistent with the data in primary nodes.
147 :param load_balancing_strategy:
148 | Enable read from replicas in READONLY mode and defines the load balancing
149 strategy that will be used for cluster node selection.
150 The data read from replicas is eventually consistent with the data in primary nodes.
151 :param dynamic_startup_nodes:
152 | Set the RedisCluster's startup nodes to all the discovered nodes.
153 If true (default value), the cluster's discovered nodes will be used to
154 determine the cluster nodes-slots mapping in the next topology refresh.
155 It will remove the initial passed startup nodes if their endpoints aren't
156 listed in the CLUSTER SLOTS output.
157 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
158 specific IP addresses, it is best to set it to false.
159 :param reinitialize_steps:
160 | Specifies the number of MOVED errors that need to occur before reinitializing
161 the whole cluster topology. If a MOVED error occurs and the cluster does not
162 need to be reinitialized on this current error handling, only the MOVED slot
163 will be patched with the redirected node.
164 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
165 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
166 0.
167 :param cluster_error_retry_attempts:
168 | @deprecated - Please configure the 'retry' object instead
169 In case 'retry' object is set - this argument is ignored!
170
171 Number of times to retry before raising an error when :class:`~.TimeoutError`,
172 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
173 or :class:`~.ClusterDownError` are encountered
174 :param retry:
175 | A retry object that defines the retry strategy and the number of
176 retries for the cluster client.
        In current implementation for the cluster client (starting from redis-py version 6.0.0)
178 the retry object is not yet fully utilized, instead it is used just to determine
179 the number of retries for the cluster client.
180 In the future releases the retry object will be used to handle the cluster client retries!
181 :param max_connections:
182 | Maximum number of connections per node. If there are no free connections & the
183 maximum number of connections are already created, a
184 :class:`~.MaxConnectionsError` is raised.
185 :param address_remap:
186 | An optional callable which, when provided with an internal network
187 address of a node, e.g. a `(host, port)` tuple, will return the address
188 where the node is reachable. This can be used to map the addresses at
189 which the nodes _think_ they are, to addresses at which a client may
190 reach them, such as when they sit behind a proxy.
191
192 | Rest of the arguments will be passed to the
193 :class:`~redis.asyncio.connection.Connection` instances when created
194
195 :raises RedisClusterException:
196 if any arguments are invalid or unknown. Eg:
197
198 - `db` != 0 or None
199 - `path` argument for unix socket connection
200 - none of the `host`/`port` & `startup_nodes` were provided
201
202 """
203
204 @classmethod
205 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
206 """
207 Return a Redis client object configured from the given URL.
208
209 For example::
210
211 redis://[[username]:[password]]@localhost:6379/0
212 rediss://[[username]:[password]]@localhost:6379/0
213
214 Three URL schemes are supported:
215
216 - `redis://` creates a TCP socket connection. See more at:
217 <https://www.iana.org/assignments/uri-schemes/prov/redis>
218 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
219 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
220
221 The username, password, hostname, path and all querystring values are passed
222 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
223 with their corresponding characters.
224
225 All querystring options are cast to their appropriate Python types. Boolean
226 arguments can be specified with string values "True"/"False" or "Yes"/"No".
227 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
228 parsed, the querystring arguments and keyword arguments are passed to
229 :class:`~redis.asyncio.connection.Connection` when created.
230 In the case of conflicting arguments, querystring arguments are used.
231 """
232 kwargs.update(parse_url(url))
233 if kwargs.pop("connection_class", None) is SSLConnection:
234 kwargs["ssl"] = True
235 return cls(**kwargs)
236
    # Declared attribute slots for the client's core state.
    # NOTE(review): __init__ also assigns attributes that are not listed here
    # (e.g. ``startup_nodes``, ``load_balancing_strategy``, ``_policy_resolver``,
    # ``_usage_counter``), which only works if a base class supplies a
    # ``__dict__`` — confirm before relying on __slots__ for memory savings.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
253
254 @deprecated_args(
255 args_to_warn=["read_from_replicas"],
256 reason="Please configure the 'load_balancing_strategy' instead",
257 version="5.3.0",
258 )
259 @deprecated_args(
260 args_to_warn=[
261 "cluster_error_retry_attempts",
262 ],
263 reason="Please configure the 'retry' object instead",
264 version="6.0.0",
265 )
266 def __init__(
267 self,
268 host: Optional[str] = None,
269 port: Union[str, int] = 6379,
270 # Cluster related kwargs
271 startup_nodes: Optional[List["ClusterNode"]] = None,
272 require_full_coverage: bool = True,
273 read_from_replicas: bool = False,
274 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
275 dynamic_startup_nodes: bool = True,
276 reinitialize_steps: int = 5,
277 cluster_error_retry_attempts: int = 3,
278 max_connections: int = 2**31,
279 retry: Optional["Retry"] = None,
280 retry_on_error: Optional[List[Type[Exception]]] = None,
281 # Client related kwargs
282 db: Union[str, int] = 0,
283 path: Optional[str] = None,
284 credential_provider: Optional[CredentialProvider] = None,
285 username: Optional[str] = None,
286 password: Optional[str] = None,
287 client_name: Optional[str] = None,
288 lib_name: Optional[str] = "redis-py",
289 lib_version: Optional[str] = get_lib_version(),
290 # Encoding related kwargs
291 encoding: str = "utf-8",
292 encoding_errors: str = "strict",
293 decode_responses: bool = False,
294 # Connection related kwargs
295 health_check_interval: float = 0,
296 socket_connect_timeout: Optional[float] = None,
297 socket_keepalive: bool = False,
298 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
299 socket_timeout: Optional[float] = None,
300 # SSL related kwargs
301 ssl: bool = False,
302 ssl_ca_certs: Optional[str] = None,
303 ssl_ca_data: Optional[str] = None,
304 ssl_cert_reqs: Union[str, VerifyMode] = "required",
305 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
306 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
307 ssl_certfile: Optional[str] = None,
308 ssl_check_hostname: bool = True,
309 ssl_keyfile: Optional[str] = None,
310 ssl_min_version: Optional[TLSVersion] = None,
311 ssl_ciphers: Optional[str] = None,
312 protocol: Optional[int] = 2,
313 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
314 event_dispatcher: Optional[EventDispatcher] = None,
315 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
316 ) -> None:
317 if db:
318 raise RedisClusterException(
319 "Argument 'db' must be 0 or None in cluster mode"
320 )
321
322 if path:
323 raise RedisClusterException(
324 "Unix domain socket is not supported in cluster mode"
325 )
326
327 if (not host or not port) and not startup_nodes:
328 raise RedisClusterException(
329 "RedisCluster requires at least one node to discover the cluster.\n"
330 "Please provide one of the following or use RedisCluster.from_url:\n"
331 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
332 " - startup_nodes: RedisCluster(startup_nodes=["
333 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
334 )
335
336 kwargs: Dict[str, Any] = {
337 "max_connections": max_connections,
338 "connection_class": Connection,
339 # Client related kwargs
340 "credential_provider": credential_provider,
341 "username": username,
342 "password": password,
343 "client_name": client_name,
344 "lib_name": lib_name,
345 "lib_version": lib_version,
346 # Encoding related kwargs
347 "encoding": encoding,
348 "encoding_errors": encoding_errors,
349 "decode_responses": decode_responses,
350 # Connection related kwargs
351 "health_check_interval": health_check_interval,
352 "socket_connect_timeout": socket_connect_timeout,
353 "socket_keepalive": socket_keepalive,
354 "socket_keepalive_options": socket_keepalive_options,
355 "socket_timeout": socket_timeout,
356 "protocol": protocol,
357 }
358
359 if ssl:
360 # SSL related kwargs
361 kwargs.update(
362 {
363 "connection_class": SSLConnection,
364 "ssl_ca_certs": ssl_ca_certs,
365 "ssl_ca_data": ssl_ca_data,
366 "ssl_cert_reqs": ssl_cert_reqs,
367 "ssl_include_verify_flags": ssl_include_verify_flags,
368 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
369 "ssl_certfile": ssl_certfile,
370 "ssl_check_hostname": ssl_check_hostname,
371 "ssl_keyfile": ssl_keyfile,
372 "ssl_min_version": ssl_min_version,
373 "ssl_ciphers": ssl_ciphers,
374 }
375 )
376
377 if read_from_replicas or load_balancing_strategy:
378 # Call our on_connect function to configure READONLY mode
379 kwargs["redis_connect_func"] = self.on_connect
380
381 if retry:
382 self.retry = retry
383 else:
384 self.retry = Retry(
385 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
386 retries=cluster_error_retry_attempts,
387 )
388 if retry_on_error:
389 self.retry.update_supported_errors(retry_on_error)
390
391 kwargs["response_callbacks"] = _RedisCallbacks.copy()
392 if kwargs.get("protocol") in ["3", 3]:
393 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
394 else:
395 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
396 self.connection_kwargs = kwargs
397
398 if startup_nodes:
399 passed_nodes = []
400 for node in startup_nodes:
401 passed_nodes.append(
402 ClusterNode(node.host, node.port, **self.connection_kwargs)
403 )
404 startup_nodes = passed_nodes
405 else:
406 startup_nodes = []
407 if host and port:
408 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
409
410 if event_dispatcher is None:
411 self._event_dispatcher = EventDispatcher()
412 else:
413 self._event_dispatcher = event_dispatcher
414
415 self.startup_nodes = startup_nodes
416 self.nodes_manager = NodesManager(
417 startup_nodes,
418 require_full_coverage,
419 kwargs,
420 dynamic_startup_nodes=dynamic_startup_nodes,
421 address_remap=address_remap,
422 event_dispatcher=self._event_dispatcher,
423 )
424 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
425 self.read_from_replicas = read_from_replicas
426 self.load_balancing_strategy = load_balancing_strategy
427 self.reinitialize_steps = reinitialize_steps
428 self.reinitialize_counter = 0
429
430 # For backward compatibility, mapping from existing policies to new one
431 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
432 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
433 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
434 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
435 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
436 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
437 SLOT_ID: RequestPolicy.DEFAULT_KEYED,
438 }
439
440 self._policies_callback_mapping: dict[
441 Union[RequestPolicy, ResponsePolicy], Callable
442 ] = {
443 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
444 self.get_random_primary_or_all_nodes(command_name)
445 ],
446 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
447 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
448 RequestPolicy.ALL_SHARDS: self.get_primaries,
449 RequestPolicy.ALL_NODES: self.get_nodes,
450 RequestPolicy.ALL_REPLICAS: self.get_replicas,
451 RequestPolicy.SPECIAL: self.get_special_nodes,
452 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
453 ResponsePolicy.DEFAULT_KEYED: lambda res: res,
454 }
455
456 self._policy_resolver = policy_resolver
457 self.commands_parser = AsyncCommandsParser()
458 self._aggregate_nodes = None
459 self.node_flags = self.__class__.NODE_FLAGS.copy()
460 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
461 self.response_callbacks = kwargs["response_callbacks"]
462 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
463 self.result_callbacks["CLUSTER SLOTS"] = (
464 lambda cmd, res, **kwargs: parse_cluster_slots(
465 list(res.values())[0], **kwargs
466 )
467 )
468
469 self._initialize = True
470 self._lock: Optional[asyncio.Lock] = None
471
472 # When used as an async context manager, we need to increment and decrement
473 # a usage counter so that we can close the connection pool when no one is
474 # using the client.
475 self._usage_counter = 0
476 self._usage_lock = asyncio.Lock()
477
    async def initialize(self) -> "RedisCluster":
        """
        Discover the cluster topology from the startup nodes and create
        connections, if the client is not initialized yet.

        Concurrency-safe: a double-checked ``_initialize`` flag around an
        asyncio lock ensures only one coroutine performs the work.
        """
        if self._initialize:
            # Create the lock lazily, on first use.
            if not self._lock:
                self._lock = asyncio.Lock()
            async with self._lock:
                # Re-check after acquiring the lock: another coroutine may have
                # completed initialization while we were waiting.
                if self._initialize:
                    try:
                        await self.nodes_manager.initialize()
                        await self.commands_parser.initialize(
                            self.nodes_manager.default_node
                        )
                        self._initialize = False
                    except BaseException:
                        # BaseException so cleanup also runs on cancellation;
                        # close both discovered-node and startup-node connections.
                        await self.nodes_manager.aclose()
                        await self.nodes_manager.aclose("startup_nodes")
                        raise
        return self
496
497 async def aclose(self) -> None:
498 """Close all connections & client if initialized."""
499 if not self._initialize:
500 if not self._lock:
501 self._lock = asyncio.Lock()
502 async with self._lock:
503 if not self._initialize:
504 self._initialize = True
505 await self.nodes_manager.aclose()
506 await self.nodes_manager.aclose("startup_nodes")
507
508 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
509 async def close(self) -> None:
510 """alias for aclose() for backwards compatibility"""
511 await self.aclose()
512
513 async def __aenter__(self) -> "RedisCluster":
514 """
515 Async context manager entry. Increments a usage counter so that the
516 connection pool is only closed (via aclose()) when no context is using
517 the client.
518 """
519 await self._increment_usage()
520 try:
521 # Initialize the client (i.e. establish connection, etc.)
522 return await self.initialize()
523 except Exception:
524 # If initialization fails, decrement the counter to keep it in sync
525 await self._decrement_usage()
526 raise
527
528 async def _increment_usage(self) -> int:
529 """
530 Helper coroutine to increment the usage counter while holding the lock.
531 Returns the new value of the usage counter.
532 """
533 async with self._usage_lock:
534 self._usage_counter += 1
535 return self._usage_counter
536
537 async def _decrement_usage(self) -> int:
538 """
539 Helper coroutine to decrement the usage counter while holding the lock.
540 Returns the new value of the usage counter.
541 """
542 async with self._usage_lock:
543 self._usage_counter -= 1
544 return self._usage_counter
545
546 async def __aexit__(self, exc_type, exc_value, traceback):
547 """
548 Async context manager exit. Decrements a usage counter. If this is the
549 last exit (counter becomes zero), the client closes its connection pool.
550 """
551 current_usage = await asyncio.shield(self._decrement_usage())
552 if current_usage == 0:
553 # This was the last active context, so disconnect the pool.
554 await asyncio.shield(self.aclose())
555
556 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
557 return self.initialize().__await__()
558
    _DEL_MESSAGE = "Unclosed RedisCluster client"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Finalizer: warn when the client is garbage-collected while still
        # initialized, i.e. aclose() was never awaited. ``warnings.warn`` and
        # ``asyncio.get_running_loop`` are bound as parameter defaults so the
        # names remain reachable even during interpreter shutdown.
        if hasattr(self, "_initialize") and not self._initialize:
            _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop — nowhere to report the leak.
                pass
573
574 async def on_connect(self, connection: Connection) -> None:
575 await connection.on_connect()
576
577 # Sending READONLY command to server to configure connection as
578 # readonly. Since each cluster node may change its server type due
579 # to a failover, we should establish a READONLY connection
580 # regardless of the server type. If this is a primary connection,
581 # READONLY would not affect executing write commands.
582 await connection.send_command("READONLY")
583 if str_if_bytes(await connection.read_response()) != "OK":
584 raise ConnectionError("READONLY command failed")
585
586 def get_nodes(self) -> List["ClusterNode"]:
587 """Get all nodes of the cluster."""
588 return list(self.nodes_manager.nodes_cache.values())
589
590 def get_primaries(self) -> List["ClusterNode"]:
591 """Get the primary nodes of the cluster."""
592 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
593
594 def get_replicas(self) -> List["ClusterNode"]:
595 """Get the replica nodes of the cluster."""
596 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
597
598 def get_random_node(self) -> "ClusterNode":
599 """Get a random node of the cluster."""
600 return random.choice(list(self.nodes_manager.nodes_cache.values()))
601
602 def get_default_node(self) -> "ClusterNode":
603 """Get the default node of the client."""
604 return self.nodes_manager.default_node
605
606 def set_default_node(self, node: "ClusterNode") -> None:
607 """
608 Set the default node of the client.
609
610 :raises DataError: if None is passed or node does not exist in cluster.
611 """
612 if not node or not self.get_node(node_name=node.name):
613 raise DataError("The requested node does not exist in the cluster.")
614
615 self.nodes_manager.default_node = node
616
617 def get_node(
618 self,
619 host: Optional[str] = None,
620 port: Optional[int] = None,
621 node_name: Optional[str] = None,
622 ) -> Optional["ClusterNode"]:
623 """Get node by (host, port) or node_name."""
624 return self.nodes_manager.get_node(host, port, node_name)
625
626 def get_node_from_key(
627 self, key: str, replica: bool = False
628 ) -> Optional["ClusterNode"]:
629 """
630 Get the cluster node corresponding to the provided key.
631
632 :param key:
633 :param replica:
634 | Indicates if a replica should be returned
635 |
636 None will returned if no replica holds this key
637
638 :raises SlotNotCoveredError: if the key is not covered by any slot.
639 """
640 slot = self.keyslot(key)
641 slot_cache = self.nodes_manager.slots_cache.get(slot)
642 if not slot_cache:
643 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
644
645 if replica:
646 if len(self.nodes_manager.slots_cache[slot]) < 2:
647 return None
648 node_idx = 1
649 else:
650 node_idx = 0
651
652 return slot_cache[node_idx]
653
654 def get_random_primary_or_all_nodes(self, command_name):
655 """
656 Returns random primary or all nodes depends on READONLY mode.
657 """
658 if self.read_from_replicas and command_name in READ_COMMANDS:
659 return self.get_random_node()
660
661 return self.get_random_primary_node()
662
663 def get_random_primary_node(self) -> "ClusterNode":
664 """
665 Returns a random primary node
666 """
667 return random.choice(self.get_primaries())
668
669 async def get_nodes_from_slot(self, command: str, *args):
670 """
671 Returns a list of nodes that hold the specified keys' slots.
672 """
673 # get the node that holds the key's slot
674 return [
675 self.nodes_manager.get_node_from_slot(
676 await self._determine_slot(command, *args),
677 self.read_from_replicas and command in READ_COMMANDS,
678 self.load_balancing_strategy if command in READ_COMMANDS else None,
679 )
680 ]
681
682 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
683 """
684 Returns a list of nodes for commands with a special policy.
685 """
686 if not self._aggregate_nodes:
687 raise RedisClusterException(
688 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
689 )
690
691 return self._aggregate_nodes
692
693 def keyslot(self, key: EncodableT) -> int:
694 """
695 Find the keyslot for a given key.
696
697 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
698 """
699 return key_slot(self.encoder.encode(key))
700
701 def get_encoder(self) -> Encoder:
702 """Get the encoder object of the client."""
703 return self.encoder
704
705 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
706 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
707 return self.connection_kwargs
708
709 def set_retry(self, retry: Retry) -> None:
710 self.retry = retry
711
712 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
713 """Set a custom response callback."""
714 self.response_callbacks[command] = callback
715
    async def _determine_nodes(
        self,
        command: str,
        *args: Any,
        request_policy: RequestPolicy,
        node_flag: Optional[str] = None,
    ) -> List["ClusterNode"]:
        """
        Resolve the list of nodes the command should be executed on.

        :param command: command name (first element of the raw command args).
        :param args: remaining raw command args, used for slot computation.
        :param request_policy: policy to apply when no node flag overrides it.
        :param node_flag: optional legacy node flag (PRIMARIES, REPLICAS, ...).
        """
        # Determine which nodes should be executed the command on.
        # Returns a list of target nodes.
        if not node_flag:
            # get the nodes group for this command if it was predefined
            node_flag = self.command_flags.get(command)

        # A recognized legacy node flag overrides the request policy passed in.
        if node_flag in self._command_flags_mapping:
            request_policy = self._command_flags_mapping[node_flag]

        policy_callback = self._policies_callback_mapping[request_policy]

        # DEFAULT_KEYED needs the full command to compute the key slot (and is
        # async); DEFAULT_KEYLESS only needs the command name; all other
        # policies take no arguments.
        if request_policy == RequestPolicy.DEFAULT_KEYED:
            nodes = await policy_callback(command, *args)
        elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
            nodes = policy_callback(command)
        else:
            nodes = policy_callback()

        # Remember the nodes used for FT.AGGREGATE so that follow-up FT.CURSOR
        # commands can be routed to the same nodes (see get_special_nodes).
        if command.lower() == "ft.aggregate":
            self._aggregate_nodes = nodes

        return nodes
745
    async def _determine_slot(self, command: str, *args: Any) -> int:
        """
        Compute the hash slot the command maps to.

        :raises RedisClusterException: if the command carries no keys or its
            keys map to more than one slot.
        """
        if self.command_flags.get(command) == SLOT_ID:
            # The command contains the slot ID
            return int(args[0])

        # Get the keys in the command

        # EVAL and EVALSHA are common enough that it's wasteful to go to the
        # redis server to parse the keys. Besides, there is a bug in redis<7.0
        # where `self._get_command_keys()` fails anyway. So, we special case
        # EVAL/EVALSHA.
        # - issue: https://github.com/redis/redis/issues/9493
        # - fix: https://github.com/redis/redis/pull/9733
        if command.upper() in ("EVAL", "EVALSHA"):
            # command syntax: EVAL "script body" num_keys ...
            if len(args) < 2:
                raise RedisClusterException(
                    f"Invalid args in command: {command, *args}"
                )
            keys = args[2 : 2 + int(args[1])]
            # if there are 0 keys, that means the script can be run on any node
            # so we can just return a random slot
            if not keys:
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
        else:
            keys = await self.commands_parser.get_keys(command, *args)
            if not keys:
                # FCALL can call a function with 0 keys, that means the function
                # can be run on any node so we can just return a random slot
                if command.upper() in ("FCALL", "FCALL_RO"):
                    return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
                raise RedisClusterException(
                    "No way to dispatch this command to Redis Cluster. "
                    "Missing key.\nYou can execute the command by specifying "
                    f"target nodes.\nCommand: {args}"
                )

        # single key command
        if len(keys) == 1:
            return self.keyslot(keys[0])

        # multi-key command; we need to make sure all keys are mapped to
        # the same slot
        slots = {self.keyslot(key) for key in keys}
        if len(slots) != 1:
            raise RedisClusterException(
                f"{command} - all keys must map to the same key slot"
            )

        return slots.pop()
796
797 def _is_node_flag(self, target_nodes: Any) -> bool:
798 return isinstance(target_nodes, str) and target_nodes in self.node_flags
799
800 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
801 if isinstance(target_nodes, list):
802 nodes = target_nodes
803 elif isinstance(target_nodes, ClusterNode):
804 # Supports passing a single ClusterNode as a variable
805 nodes = [target_nodes]
806 elif isinstance(target_nodes, dict):
807 # Supports dictionaries of the format {node_name: node}.
808 # It enables to execute commands with multi nodes as follows:
809 # rc.cluster_save_config(rc.get_primaries())
810 nodes = list(target_nodes.values())
811 else:
812 raise TypeError(
813 "target_nodes type can be one of the following: "
814 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
815 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
816 f"The passed type is {type(target_nodes)}"
817 )
818 return nodes
819
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        # Explicit (non-flag) target nodes disable retries: the caller pinned
        # the destination, so redirect/retry logic does not apply.
        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        command_policies = await self._policy_resolver.resolve(args[0].lower())

        # When the resolver has no policy for this command, synthesize one:
        # legacy command flags take precedence, then keyed/keyless defaults.
        if not command_policies and not target_nodes_specified:
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    command_policies = CommandPolicies()
                else:
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        for _ in range(execute_attempts):
            if self._initialize:
                await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently, then merge the
                    # per-node results keyed by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache should be reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    continue
                else:
                    # raise the exception
                    raise e
936
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Execute a single command against ``target_node``, following cluster
        redirections (MOVED/ASK) for up to :attr:`RedisClusterRequestTTL`
        attempts.

        :param target_node: the node initially chosen for this command
        :param args: raw command args; ``args[0]`` is the command name
        :param kwargs: forwarded to the node's ``execute_command``
        :raises ClusterError: when the redirection budget (TTL) is exhausted
        """
        # Redirection state: ``asking`` means follow an ASK redirect once;
        # ``moved`` means re-resolve the slot owner after a MOVED redirect.
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # ASK redirect: the redirected node must receive ASKING
                    # before it will serve the key for this one request.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                return await target_node.execute_command(*args, **kwargs)
            except BusyLoadingError:
                # The node is loading its dataset; let the caller decide.
                raise
            except MaxConnectionsError:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                raise
            except (ConnectionError, TimeoutError):
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                raise
            except (ClusterDownError, SlotNotCoveredError):
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    self.nodes_manager.move_slot(e)
                    moved = True
            except AskError as e:
                # Remember where to send ASKING + the command on the next loop.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
            except TryAgainError:
                # Back off slightly once half the TTL budget has been spent.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)

        raise ClusterError("TTL exhausted.")
1030
1031 def pipeline(
1032 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1033 ) -> "ClusterPipeline":
1034 """
1035 Create & return a new :class:`~.ClusterPipeline` object.
1036
1037 Cluster implementation of pipeline does not support transaction or shard_hint.
1038
1039 :raises RedisClusterException: if transaction or shard_hint are truthy values
1040 """
1041 if shard_hint:
1042 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1043
1044 return ClusterPipeline(self, transaction)
1045
1046 def lock(
1047 self,
1048 name: KeyT,
1049 timeout: Optional[float] = None,
1050 sleep: float = 0.1,
1051 blocking: bool = True,
1052 blocking_timeout: Optional[float] = None,
1053 lock_class: Optional[Type[Lock]] = None,
1054 thread_local: bool = True,
1055 raise_on_release_error: bool = True,
1056 ) -> Lock:
1057 """
1058 Return a new Lock object using key ``name`` that mimics
1059 the behavior of threading.Lock.
1060
1061 If specified, ``timeout`` indicates a maximum life for the lock.
1062 By default, it will remain locked until release() is called.
1063
1064 ``sleep`` indicates the amount of time to sleep per loop iteration
1065 when the lock is in blocking mode and another client is currently
1066 holding the lock.
1067
1068 ``blocking`` indicates whether calling ``acquire`` should block until
1069 the lock has been acquired or to fail immediately, causing ``acquire``
1070 to return False and the lock not being acquired. Defaults to True.
1071 Note this value can be overridden by passing a ``blocking``
1072 argument to ``acquire``.
1073
1074 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1075 spend trying to acquire the lock. A value of ``None`` indicates
1076 continue trying forever. ``blocking_timeout`` can be specified as a
1077 float or integer, both representing the number of seconds to wait.
1078
1079 ``lock_class`` forces the specified lock implementation. Note that as
1080 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1081 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1082 you have created your own custom lock class.
1083
1084 ``thread_local`` indicates whether the lock token is placed in
1085 thread-local storage. By default, the token is placed in thread local
1086 storage so that a thread only sees its token, not a token set by
1087 another thread. Consider the following timeline:
1088
1089 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1090 thread-1 sets the token to "abc"
1091 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1092 Lock instance.
1093 time: 5, thread-1 has not yet completed. redis expires the lock
1094 key.
1095 time: 5, thread-2 acquired `my-lock` now that it's available.
1096 thread-2 sets the token to "xyz"
1097 time: 6, thread-1 finishes its work and calls release(). if the
1098 token is *not* stored in thread local storage, then
1099 thread-1 would see the token value as "xyz" and would be
1100 able to successfully release the thread-2's lock.
1101
1102 ``raise_on_release_error`` indicates whether to raise an exception when
1103 the lock is no longer owned when exiting the context manager. By default,
1104 this is True, meaning an exception will be raised. If False, the warning
1105 will be logged and the exception will be suppressed.
1106
1107 In some use cases it's necessary to disable thread local storage. For
1108 example, if you have code where one thread acquires a lock and passes
1109 that lock instance to a worker thread to release later. If thread
1110 local storage isn't disabled in this case, the worker thread won't see
1111 the token set by the thread that acquired the lock. Our assumption
1112 is that these cases aren't common and as such default to using
1113 thread local storage."""
1114 if lock_class is None:
1115 lock_class = Lock
1116 return lock_class(
1117 self,
1118 name,
1119 timeout=timeout,
1120 sleep=sleep,
1121 blocking=blocking,
1122 blocking_timeout=blocking_timeout,
1123 thread_local=thread_local,
1124 raise_on_release_error=raise_on_release_error,
1125 )
1126
1127 async def transaction(
1128 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1129 ):
1130 """
1131 Convenience method for executing the callable `func` as a transaction
1132 while watching all keys specified in `watches`. The 'func' callable
1133 should expect a single argument which is a Pipeline object.
1134 """
1135 shard_hint = kwargs.pop("shard_hint", None)
1136 value_from_callable = kwargs.pop("value_from_callable", False)
1137 watch_delay = kwargs.pop("watch_delay", None)
1138 async with self.pipeline(True, shard_hint) as pipe:
1139 while True:
1140 try:
1141 if watches:
1142 await pipe.watch(*watches)
1143 func_value = await func(pipe)
1144 exec_value = await pipe.execute()
1145 return func_value if value_from_callable else exec_value
1146 except WatchError:
1147 if watch_delay is not None and watch_delay > 0:
1148 time.sleep(watch_delay)
1149 continue
1150
1151
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are pooled: ``_connections`` holds every connection created
    for this node, while ``_free`` holds the idle subset available for reuse.
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # Resolve "localhost" eagerly so the node name derived below is
        # stable and comparable with names built from resolved addresses.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # NOTE: this pop() also removes the key from self.connection_kwargs,
        # since both names refer to the same dict object.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are identified solely by their "host:port" name.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn once if the node is garbage-collected with a live connection.
        # ``_warn``/``_grl`` are bound as default arguments so they stay
        # reachable even during interpreter shutdown.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop - nothing more to report to.
                    pass
                break

    async def disconnect(self) -> None:
        """Disconnect all connections of this node, re-raising the first error."""
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Return a free connection, creating a new one if the pool is not full.

        :raises MaxConnectionsError: when ``max_connections`` connections
            already exist and none are free
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, it will be disconnected
        lazily when next acquired via disconnect_if_needed().
        """
        self._free.append(connection)

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected when released back to the pool.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one response for ``command`` from ``connection`` and run the
        matching response callback, if any.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Some commands declare a substitute value for error responses.
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command on a pooled connection and return the parsed
        response. The connection is always returned to the free queue.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self.disconnect_if_needed(connection)
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in a single round trip, storing each result
        (or exception) on the command object itself.

        :returns: True if any command's response was an exception
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                cmd.result = e
                ret = True

        # Release connection
        await self.disconnect_if_needed(connection)
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every currently free connection with the new token.
        Connections are drained into a temporary queue so each one is AUTHed
        exactly once, then returned to the free queue.
        """
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1414
1415
class NodesManager:
    """
    Discover and cache the cluster topology: the known nodes
    (``nodes_cache``) and the slot -> nodes mapping (``slots_cache``),
    keeping both in sync across redirections and reinitializations.
    """

    __slots__ = (
        "_dynamic_startup_nodes",
        "_event_dispatcher",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "_epoch",
        "read_load_balancer",
        "_initialize_lock",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )

    def __init__(
        self,
        startup_nodes: List["ClusterNode"],
        require_full_coverage: bool,
        connection_kwargs: Dict[str, Any],
        dynamic_startup_nodes: bool = True,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        self.startup_nodes = {node.name: node for node in startup_nodes}
        self.require_full_coverage = require_full_coverage
        self.connection_kwargs = connection_kwargs
        self.address_remap = address_remap

        self.default_node: "ClusterNode" = None
        self.nodes_cache: Dict[str, "ClusterNode"] = {}
        self.slots_cache: Dict[int, List["ClusterNode"]] = {}
        # Incremented after every successful initialize(); lets concurrent
        # initialize() calls detect that the topology was already refreshed.
        self._epoch: int = 0
        self.read_load_balancer = LoadBalancer()
        self._initialize_lock: asyncio.Lock = asyncio.Lock()

        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """
        Look up a cached node either by (host, port) or by node name.

        :raises DataError: if neither a node name nor host+port was given
        """
        if host and port:
            # the user passed host and port
            if host == "localhost":
                host = socket.gethostbyname(host)
            return self.nodes_cache.get(get_node_name(host=host, port=port))
        elif node_name:
            return self.nodes_cache.get(node_name)
        else:
            raise DataError(
                "get_node requires one of the following: 1. node name 2. host and port"
            )

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge ``new`` into ``old`` in place. Nodes already present in ``old``
        are preserved (their connections are marked for reconnect); nodes
        missing from ``new`` are dropped when ``remove_old`` is True.
        """
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    # Node is removed from cache before disconnect starts,
                    # so it won't be found in lookups during disconnect
                    # Mark active connections for reconnect so they get disconnected after current command completes
                    # and disconnect free connections immediately
                    # the node is removed from the cache before the connections changes so it won't be used and should be safe
                    # not to wait for the disconnects
                    removed_node = old.pop(name)
                    removed_node.update_active_connections_for_reconnect()
                    asyncio.create_task(removed_node.disconnect_free_connections())  # noqa

        for name, node in new.items():
            if name in old:
                # Preserve the existing node but mark connections for reconnect.
                # This method is sync so we can't call disconnect_free_connections()
                # which is async. Instead, we mark free connections for reconnect
                # and they will be lazily disconnected when acquired via
                # disconnect_if_needed() to avoid race conditions.
                # TODO: Make this method async in the next major release to allow
                # immediate disconnection of free connections.
                existing_node = old[name]
                existing_node.update_active_connections_for_reconnect()
                for conn in existing_node._free:
                    conn.mark_for_reconnect()
                continue
            # New node is detected and should be added to the pool
            old[name] = node

    def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
        """
        Move a failing node to the end of startup_nodes and nodes_cache so it's
        tried last during reinitialization and when selecting the default node.
        If the node is not in the respective list, nothing is done.
        """
        # Move in startup_nodes
        if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
            node = self.startup_nodes.pop(node_name)
            self.startup_nodes[node_name] = node  # Re-insert at end

        # Move in nodes_cache - this affects get_nodes_by_server_type ordering
        # which is used to select the default_node during initialize()
        if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
            node = self.nodes_cache.pop(node_name)
            self.nodes_cache[node_name] = node  # Re-insert at end

    def move_slot(self, e: AskError | MovedError):
        """
        Patch the slots cache after a MOVED/ASK redirection, promoting the
        redirected node to primary owner of ``e.slot_id``.
        """
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        slot_nodes = self.slots_cache[e.slot_id]
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op

    def get_node_from_slot(
        self,
        slot: int,
        read_from_replicas: bool = False,
        load_balancing_strategy=None,
    ) -> "ClusterNode":
        """
        Return the node serving ``slot`` - the primary, or a node chosen by
        the load-balancing strategy when replicas may serve reads.

        :raises SlotNotCoveredError: if no node covers the slot
        """
        if read_from_replicas is True and load_balancing_strategy is None:
            load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN

        try:
            if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
                # get the server index using the strategy defined in load_balancing_strategy
                primary_name = self.slots_cache[slot][0].name
                node_idx = self.read_load_balancer.get_server_index(
                    primary_name, len(self.slots_cache[slot]), load_balancing_strategy
                )
                return self.slots_cache[slot][node_idx]
            return self.slots_cache[slot][0]
        except (IndexError, TypeError):
            raise SlotNotCoveredError(
                f'Slot "{slot}" not covered by the cluster. '
                f'"require_full_coverage={self.require_full_coverage}"'
            )

    def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
        """Return all cached nodes of the given server type (PRIMARY/REPLICA)."""
        return [
            node
            for node in self.nodes_cache.values()
            if node.server_type == server_type
        ]

    async def initialize(self) -> None:
        """
        (Re)build the node and slot caches by querying CLUSTER SLOTS on the
        startup nodes, trying each in turn until one yields full coverage.

        :raises RedisClusterException: if no startup node is reachable, if a
            node is not in cluster mode, if startup nodes disagree on the slot
            layout, or if coverage is incomplete while
            ``require_full_coverage`` is set
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        epoch = self._epoch

        async with self._initialize_lock:
            if self._epoch != epoch:
                # another initialize call has already reinitialized the
                # nodes since we started waiting for the lock;
                # we don't need to do it again.
                return

            # Convert to tuple to prevent RuntimeError if self.startup_nodes
            # is modified during iteration
            for startup_node in tuple(self.startup_nodes.values()):
                try:
                    # Make sure cluster mode is enabled on this node
                    try:
                        self._event_dispatcher.dispatch(
                            AfterAsyncClusterInstantiationEvent(
                                self.nodes_cache,
                                self.connection_kwargs.get("credential_provider", None),
                            )
                        )
                        cluster_slots = await startup_node.execute_command(
                            "CLUSTER SLOTS"
                        )
                    except ResponseError:
                        raise RedisClusterException(
                            "Cluster mode is not enabled on this node"
                        )
                    startup_nodes_reachable = True
                except Exception as e:
                    # Try the next startup node.
                    # The exception is saved and raised only if we have no more nodes.
                    exception = e
                    continue

                # CLUSTER SLOTS command results in the following output:
                # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
                # where each node contains the following list: [IP, port, node_id]
                # Therefore, cluster_slots[0][2][0] will be the IP address of the
                # primary node of the first slot section.
                # If there's only one server in the cluster, its ``host`` is ''
                # Fix it to the host in startup_nodes
                if (
                    len(cluster_slots) == 1
                    and not cluster_slots[0][2][0]
                    and len(self.startup_nodes) == 1
                ):
                    cluster_slots[0][2][0] = startup_node.host

                for slot in cluster_slots:
                    for i in range(2, len(slot)):
                        slot[i] = [str_if_bytes(val) for val in slot[i]]
                    primary_node = slot[2]
                    host = primary_node[0]
                    if host == "":
                        host = startup_node.host
                    port = int(primary_node[1])
                    host, port = self.remap_host_port(host, port)

                    nodes_for_slot = []

                    target_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_node:
                        target_node = ClusterNode(
                            host, port, PRIMARY, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_node.name] = target_node
                    nodes_for_slot.append(target_node)

                    replica_nodes = slot[3:]
                    for replica_node in replica_nodes:
                        host = replica_node[0]
                        # NOTE(review): unlike the primary port above, this is
                        # not coerced with int() - presumably the reply already
                        # carries an int here; confirm.
                        port = replica_node[1]
                        host, port = self.remap_host_port(host, port)

                        target_replica_node = tmp_nodes_cache.get(
                            get_node_name(host, port)
                        )
                        if not target_replica_node:
                            target_replica_node = ClusterNode(
                                host, port, REPLICA, **self.connection_kwargs
                            )
                        # add this node to the nodes cache
                        tmp_nodes_cache[target_replica_node.name] = target_replica_node
                        nodes_for_slot.append(target_replica_node)

                    for i in range(int(slot[0]), int(slot[1]) + 1):
                        if i not in tmp_slots:
                            tmp_slots[i] = nodes_for_slot
                        else:
                            # Validate that 2 nodes want to use the same slot cache
                            # setup
                            tmp_slot = tmp_slots[i][0]
                            if tmp_slot.name != target_node.name:
                                disagreements.append(
                                    f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                                )

                                if len(disagreements) > 5:
                                    raise RedisClusterException(
                                        f"startup_nodes could not agree on a valid "
                                        f"slots cache: {', '.join(disagreements)}"
                                    )

                # Validate if all slots are covered or if we should try next startup node
                fully_covered = True
                for i in range(REDIS_CLUSTER_HASH_SLOTS):
                    if i not in tmp_slots:
                        fully_covered = False
                        break
                if fully_covered:
                    break

            if not startup_nodes_reachable:
                raise RedisClusterException(
                    f"Redis Cluster cannot be connected. Please provide at least "
                    f"one reachable node: {str(exception)}"
                ) from exception

            # Check if the slots are not fully covered
            if not fully_covered and self.require_full_coverage:
                # Despite the requirement that the slots be covered, there
                # isn't a full coverage
                raise RedisClusterException(
                    f"All slots are not covered after query all startup_nodes. "
                    f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                    f"covered..."
                )

            # Set the tmp variables to the real variables
            self.slots_cache = tmp_slots
            self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

            if self._dynamic_startup_nodes:
                # Populate the startup nodes with all discovered nodes
                self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

            # Set the default node
            self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
            self._epoch += 1

    async def aclose(self, attr: str = "nodes_cache") -> None:
        """Disconnect every node held in the given attribute (dict of nodes)."""
        self.default_node = None
        await asyncio.gather(
            *(
                asyncio.create_task(node.disconnect())
                for node in getattr(self, attr).values()
            )
        )

    def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
        """
        Remap the host and port returned from the cluster to a different
        internal value. Useful if the client is not connecting directly
        to the cluster.
        """
        if self.address_remap:
            return self.address_remap((host, port))
        return host, port
1766
1767
1768class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
1769 """
1770 Create a new ClusterPipeline object.
1771
1772 Usage::
1773
1774 result = await (
1775 rc.pipeline()
1776 .set("A", 1)
1777 .get("A")
1778 .hset("K", "F", "V")
1779 .hgetall("K")
1780 .mset_nonatomic({"A": 2, "B": 3})
1781 .get("A")
1782 .get("B")
1783 .delete("A", "B", "K")
1784 .execute()
1785 )
1786 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
1787
1788 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
1789 are split across multiple nodes, you'll get multiple results for them in the array.
1790
1791 Retryable errors:
1792 - :class:`~.ClusterDownError`
1793 - :class:`~.ConnectionError`
1794 - :class:`~.TimeoutError`
1795
1796 Redirection errors:
1797 - :class:`~.TryAgainError`
1798 - :class:`~.MovedError`
1799 - :class:`~.AskError`
1800
1801 :param client:
1802 | Existing :class:`~.RedisCluster` client
1803 """
1804
1805 __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1806
1807 def __init__(
1808 self, client: RedisCluster, transaction: Optional[bool] = None
1809 ) -> None:
1810 self.cluster_client = client
1811 self._transaction = transaction
1812 self._execution_strategy: ExecutionStrategy = (
1813 PipelineStrategy(self)
1814 if not self._transaction
1815 else TransactionStrategy(self)
1816 )
1817
    async def initialize(self) -> "ClusterPipeline":
        """Initialize the underlying execution strategy and return ``self``."""
        await self._execution_strategy.initialize()
        return self
1821
1822 async def __aenter__(self) -> "ClusterPipeline":
1823 return await self.initialize()
1824
1825 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1826 await self.reset()
1827
1828 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1829 return self.initialize().__await__()
1830
1831 def __bool__(self) -> bool:
1832 "Pipeline instances should always evaluate to True on Python 3+"
1833 return True
1834
1835 def __len__(self) -> int:
1836 return len(self._execution_strategy)
1837
1838 def execute_command(
1839 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1840 ) -> "ClusterPipeline":
1841 """
1842 Append a raw command to the pipeline.
1843
1844 :param args:
1845 | Raw command args
1846 :param kwargs:
1847
1848 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
1849 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1850 - Rest of the kwargs are passed to the Redis connection
1851 """
1852 return self._execution_strategy.execute_command(*args, **kwargs)
1853
1854 async def execute(
1855 self, raise_on_error: bool = True, allow_redirections: bool = True
1856 ) -> List[Any]:
1857 """
1858 Execute the pipeline.
1859
1860 It will retry the commands as specified by retries specified in :attr:`retry`
1861 & then raise an exception.
1862
1863 :param raise_on_error:
1864 | Raise the first error if there are any errors
1865 :param allow_redirections:
1866 | Whether to retry each failed command individually in case of redirection
1867 errors
1868
1869 :raises RedisClusterException: if target_nodes is not provided & the command
1870 can't be mapped to a slot
1871 """
1872 try:
1873 return await self._execution_strategy.execute(
1874 raise_on_error, allow_redirections
1875 )
1876 finally:
1877 await self.reset()
1878
1879 def _split_command_across_slots(
1880 self, command: str, *keys: KeyT
1881 ) -> "ClusterPipeline":
1882 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1883 self.execute_command(command, *slot_keys)
1884
1885 return self
1886
1887 async def reset(self):
1888 """
1889 Reset back to empty pipeline.
1890 """
1891 await self._execution_strategy.reset()
1892
1893 def multi(self):
1894 """
1895 Start a transactional block of the pipeline after WATCH commands
1896 are issued. End the transactional block with `execute`.
1897 """
1898 self._execution_strategy.multi()
1899
1900 async def discard(self):
1901 """ """
1902 await self._execution_strategy.discard()
1903
1904 async def watch(self, *names):
1905 """Watches the values at keys ``names``"""
1906 await self._execution_strategy.watch(*names)
1907
1908 async def unwatch(self):
1909 """Unwatches all previously specified keys"""
1910 await self._execution_strategy.unwatch()
1911
1912 async def unlink(self, *names):
1913 await self._execution_strategy.unlink(*names)
1914
1915 def mset_nonatomic(
1916 self, mapping: Mapping[AnyKeyT, EncodableT]
1917 ) -> "ClusterPipeline":
1918 return self._execution_strategy.mset_nonatomic(mapping)
1919
1920
# Install stubs that reject commands which may not run inside a cluster
# pipeline. ``mset_nonatomic`` is skipped because ClusterPipeline provides
# its own implementation of it.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
1927
1928
class PipelineCommand:
    """One queued pipeline command together with its eventual result."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index of this command within the pipeline's queue.
        self.position = position
        # Raw command arguments and execution options, kept verbatim.
        self.args = args
        self.kwargs = kwargs
        # Filled in after execution: the reply, or the exception raised.
        self.result: Union[Any, Exception] = None
        # Routing policies resolved for this command, if any.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)
1939
1940
class ExecutionStrategy(ABC):
    """
    Interface implemented by the pipeline execution back-ends.

    ``ClusterPipeline`` delegates every public operation to one of these:
    a plain batching strategy, or a MULTI/EXEC transactional strategy.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """Prepare the strategy for use. See ClusterPipeline.initialize()."""
        ...

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Append a raw command. See ClusterPipeline.execute_command()."""
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands, retrying per the client's :attr:`retry`
        policy before raising. See ClusterPipeline.execute().
        """
        ...

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue MSET commands per slot for the given mapping.
        See ClusterPipeline.mset_nonatomic().
        """
        ...

    @abstractmethod
    async def reset(self):
        """Drop queued commands and per-run state. See ClusterPipeline.reset()."""
        ...

    @abstractmethod
    def multi(self):
        """Enter an explicit transactional block. See ClusterPipeline.multi()."""
        ...

    @abstractmethod
    async def watch(self, *names):
        """Watch the given keys. See ClusterPipeline.watch()."""
        ...

    @abstractmethod
    async def unwatch(self):
        """Stop watching all keys. See ClusterPipeline.unwatch()."""
        ...

    @abstractmethod
    async def discard(self):
        """Abort the transactional block, if any."""
        ...

    @abstractmethod
    async def unlink(self, *names):
        """Unlink keys specified by ``names``. See ClusterPipeline.unlink()."""
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of commands currently queued."""
        ...
2039
2040
class AbstractStrategy(ExecutionStrategy):
    """Shared plumbing (queueing, init, error annotation) for strategies."""

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the cluster client if needed and clear the queue."""
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Append a raw command at the next queue position."""
        position = len(self._command_queue)
        self._command_queue.append(PipelineCommand(position, *args, **kwargs))
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled
        """
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        ...

    @abstractmethod
    async def reset(self):
        ...

    @abstractmethod
    def multi(self):
        ...

    @abstractmethod
    async def watch(self, *names):
        ...

    @abstractmethod
    async def unwatch(self):
        ...

    @abstractmethod
    async def discard(self):
        ...

    @abstractmethod
    async def unlink(self, *names):
        ...

    def __len__(self) -> int:
        return len(self._command_queue)
2109
2110
class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional pipeline execution: queued commands are grouped by
    their target node, flushed to all nodes concurrently, and the replies
    are returned in queue order.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one MSET command per hash slot covering ``mapping``.

        Not atomic: each slot's MSET is sent and may fail independently.
        """
        encoder = self._pipe.cluster_client.encoder

        # A single MSET may only touch keys of one slot, so bucket the flat
        # [key, value, key, value, ...] argument lists by slot first.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the queued commands, retrying retryable cluster errors.

        :param raise_on_error: raise the first command error, if any
        :param allow_redirections: retry redirected commands individually
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Drop cached connections/topology and retry with a
                        # freshly initialized cluster view.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # Retries exhausted; surface the last error.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        Single execution pass: resolve routing policies, group commands per
        target node, flush sub-pipelines concurrently, then handle
        redirections and errors.
        """
        # Commands with no result yet, or whose previous attempt errored.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # node name -> (node, [commands routed to that node])
        nodes = {}
        for cmd in todo:
            # NOTE: pop() mutates cmd.kwargs so the routing hint is not sent
            # to the server; on a later retry pass the hint is gone.
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Caller pinned explicit node objects; policies only need a
                # default placeholder.
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    # No resolved policy for this command: fall back first to
                    # the client's static command flags, then to slot-based
                    # defaults.
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            # A pipelined command must map to exactly one node.
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Flush each node's sub-pipeline concurrently; each task reports
        # whether any of its commands errored (results land on cmd.result).
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        if any(errors):
            if allow_redirections:
                # send each errored command individually so the regular
                # single-command path can follow the MOVED/ASK redirection
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                # Raise the first remaining failure, annotated with its
                # position in the pipeline for easier debugging.
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

        default_cluster_node = client.get_default_node()

        # Check whether the default node was used. In some cases,
        # 'client.get_default_node()' may return None. The check below
        # prevents a potential AttributeError.
        if default_cluster_node is not None:
            default_node = nodes.get(default_cluster_node.name)
            if default_node is not None:
                # This pipeline execution used the default node, check if we need
                # to replace it.
                # Note: when the error is raised we'll reset the default node in the
                # caller function.
                for cmd in default_node[1]:
                    # Check if it has a command that failed with a relevant
                    # exception
                    if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                        client.replace_default_node()
                        break

        # Results (or exceptions) in original queue order.
        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        # WATCH/MULTI semantics require the transactional strategy.
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue an UNLINK for a single key (multi-key unlink unsupported)."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2314
2315
class TransactionStrategy(AbstractStrategy):
    """
    Transactional (MULTI/EXEC) pipeline execution pinned to the single node
    that owns the transaction's slot, with WATCH support over a dedicated
    connection.
    """

    # Commands that never map to a hash slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands sent immediately instead of being queued for EXEC.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands whose reply clears the local WATCH state.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        # Slots touched by queued/watched commands; must collapse to one.
        self._pipeline_slots: Set[int] = set()
        # Node/connection pinned for the lifetime of the transaction.
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Copy so widening the supported errors doesn't leak into the client.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not been
        altered as long as the watch commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to the
        node that owns that slot and reuse it.
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        # All slots in the set are required to map to the same node, so any
        # one of them identifies it.
        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            list(self._pipeline_slots)[0], False
        )
        self._transaction_node = node

        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        # NOTE(review): for immediate commands (WATCH/UNWATCH outside an
        # explicit MULTI), `_execute_command` returns the coroutine from
        # `_immediate_execute_command` *without awaiting it* (see below), so
        # `response` may itself be a coroutine object that callers such as
        # `watch()` then await in their own event loop — confirm this
        # cross-loop handoff is intentional.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Either execute the command immediately (WATCH/UNWATCH, or any command
        while watching outside MULTI) or queue it for EXEC, tracking the hash
        slots it touches.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                # All keys in a cluster transaction must live in one slot.
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # NOTE(review): intentionally(?) not awaited — the coroutine is
            # handed back through `asyncio.run` in `execute_command`'s thread
            # and awaited by the caller (e.g. `watch()`); confirm.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        # WATCH must precede MULTI, mirroring standalone Redis semantics.
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        """Send one command on the pinned connection, with retry handling."""
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        # These commands clear any watched keys on the server side, so
        # mirror that locally.
        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        """
        Retry hook: on redirects/connection errors, drop the pinned
        connection and (periodically) refresh the cluster topology.
        """
        if self._watching:
            # A redirect mid-execution means the watched slot moved; the
            # optimistic lock can no longer be trusted.
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            if self._transaction_connection and self._transaction_node:
                # Disconnect and release back to pool
                await self._transaction_connection.disconnect()
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None

            self._pipe.cluster_client.reinitialize_counter += 1
            if (
                self._pipe.cluster_client.reinitialize_steps
                and self._pipe.cluster_client.reinitialize_counter
                % self._pipe.cluster_client.reinitialize_steps
                == 0
            ):
                await self._pipe.cluster_client.nodes_manager.initialize()
                # NOTE(review): this sets `reinitialize_counter` on the
                # strategy object, not on the cluster client whose counter
                # was incremented above — likely intended to reset the
                # client's counter; confirm.
                self.reinitialize_counter = 0
            else:
                if isinstance(error, AskError):
                    self._pipe.cluster_client.nodes_manager.move_slot(error)

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        # Splitting MSET across slots contradicts transaction atomicity.
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands as a MULTI/EXEC block with retries.

        Note: ``allow_redirections`` is accepted for interface compatibility
        but not used by the transactional path.
        """
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send MULTI + queued commands + EXEC in one packed write, then parse
        each reply, translating errors and WATCH failures.
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands flagged EMPTY_RESPONSE are not sent; their placeholder
        # result is injected into the response later.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        # A None reply to EXEC means a watched key changed and the
        # transaction was aborted by the server.
        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): entries appended above as bare exceptions (not
        # (index, value) tuples) would fail this unpacking; in practice a
        # queued-command error makes EXEC raise ExecAbortError first, so
        # only the (i, EMPTY_RESPONSE) tuples should reach here — confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)
        return data

    async def reset(self):
        """Clear the queue and all transactional state; UNWATCH if needed."""
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection and self._transaction_node:
                    await self._transaction_connection.disconnect()
                    self._transaction_node.release(self._transaction_connection)
                    self._transaction_connection = None

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        """Enter an explicit MULTI block; only valid before any queued command."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        """Watch the given keys (must precede MULTI)."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Stop watching; a no-op (returns True) when nothing is watched."""
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        """Abort the transaction by resetting all state."""
        await self.reset()

    async def unlink(self, *names):
        return self.execute_command("UNLINK", *names)