1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34)
35from redis.asyncio.client import ResponseCallbackT
36from redis.asyncio.connection import Connection, SSLConnection, parse_url
37from redis.asyncio.lock import Lock
38from redis.asyncio.retry import Retry
39from redis.auth.token import TokenInterface
40from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
41from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
42from redis.cluster import (
43 PIPELINE_BLOCKED_COMMANDS,
44 PRIMARY,
45 REPLICA,
46 SLOT_ID,
47 AbstractRedisCluster,
48 LoadBalancer,
49 LoadBalancingStrategy,
50 block_pipeline_command,
51 get_node_name,
52 parse_cluster_slots,
53)
54from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
55from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
56from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
57from redis.credentials import CredentialProvider
58from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
59from redis.exceptions import (
60 AskError,
61 BusyLoadingError,
62 ClusterDownError,
63 ClusterError,
64 ConnectionError,
65 CrossSlotTransactionError,
66 DataError,
67 ExecAbortError,
68 InvalidPipelineStack,
69 MaxConnectionsError,
70 MovedError,
71 RedisClusterException,
72 RedisError,
73 ResponseError,
74 SlotNotCoveredError,
75 TimeoutError,
76 TryAgainError,
77 WatchError,
78)
79from redis.typing import AnyKeyT, EncodableT, KeyT
80from redis.utils import (
81 SSL_AVAILABLE,
82 deprecated_args,
83 deprecated_function,
84 get_lib_version,
85 safe_str,
86 str_if_bytes,
87 truncate_text,
88)
89
# The ssl module may be unavailable in some interpreter builds; fall back to
# None placeholders so the SSL-typed parameter annotations below still
# evaluate at import time.
if SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Accepted forms of the ``target_nodes`` argument used by many commands:
# a node-flag string, a single node, a list of nodes, or a {name: node} dict.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
100
101
102class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
103 """
104 Create a new RedisCluster client.
105
106 Pass one of parameters:
107
108 - `host` & `port`
109 - `startup_nodes`
110
111 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
112 | Use ``await`` :meth:`close` to disconnect connections & close client.
113
114 Many commands support the target_nodes kwarg. It can be one of the
115 :attr:`NODE_FLAGS`:
116
117 - :attr:`PRIMARIES`
118 - :attr:`REPLICAS`
119 - :attr:`ALL_NODES`
120 - :attr:`RANDOM`
121 - :attr:`DEFAULT_NODE`
122
123 Note: This client is not thread/process/fork safe.
124
125 :param host:
126 | Can be used to point to a startup node
127 :param port:
128 | Port used if **host** is provided
129 :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
131 :param require_full_coverage:
132 | When set to ``False``: the client will not require a full coverage of
133 the slots. However, if not all slots are covered, and at least one node
134 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
135 a :class:`~.ClusterDownError` for some key-based commands.
136 | When set to ``True``: all slots must be covered to construct the cluster
137 client. If not all slots are covered, :class:`~.RedisClusterException` will be
138 thrown.
139 | See:
140 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
141 :param read_from_replicas:
142 | @deprecated - please use load_balancing_strategy instead
143 | Enable read from replicas in READONLY mode.
144 When set to true, read commands will be assigned between the primary and
145 its replications in a Round-Robin manner.
146 The data read from replicas is eventually consistent with the data in primary nodes.
147 :param load_balancing_strategy:
148 | Enable read from replicas in READONLY mode and defines the load balancing
149 strategy that will be used for cluster node selection.
150 The data read from replicas is eventually consistent with the data in primary nodes.
151 :param dynamic_startup_nodes:
152 | Set the RedisCluster's startup nodes to all the discovered nodes.
153 If true (default value), the cluster's discovered nodes will be used to
154 determine the cluster nodes-slots mapping in the next topology refresh.
155 It will remove the initial passed startup nodes if their endpoints aren't
156 listed in the CLUSTER SLOTS output.
157 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
158 specific IP addresses, it is best to set it to false.
159 :param reinitialize_steps:
160 | Specifies the number of MOVED errors that need to occur before reinitializing
161 the whole cluster topology. If a MOVED error occurs and the cluster does not
162 need to be reinitialized on this current error handling, only the MOVED slot
163 will be patched with the redirected node.
164 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
165 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
166 0.
167 :param cluster_error_retry_attempts:
168 | @deprecated - Please configure the 'retry' object instead
169 In case 'retry' object is set - this argument is ignored!
170
171 Number of times to retry before raising an error when :class:`~.TimeoutError`,
172 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
173 or :class:`~.ClusterDownError` are encountered
174 :param retry:
175 | A retry object that defines the retry strategy and the number of
176 retries for the cluster client.
        In current implementation for the cluster client (starting from redis-py version 6.0.0)
178 the retry object is not yet fully utilized, instead it is used just to determine
179 the number of retries for the cluster client.
180 In the future releases the retry object will be used to handle the cluster client retries!
181 :param max_connections:
182 | Maximum number of connections per node. If there are no free connections & the
183 maximum number of connections are already created, a
184 :class:`~.MaxConnectionsError` is raised.
185 :param address_remap:
186 | An optional callable which, when provided with an internal network
187 address of a node, e.g. a `(host, port)` tuple, will return the address
188 where the node is reachable. This can be used to map the addresses at
189 which the nodes _think_ they are, to addresses at which a client may
190 reach them, such as when they sit behind a proxy.
191
192 | Rest of the arguments will be passed to the
193 :class:`~redis.asyncio.connection.Connection` instances when created
194
195 :raises RedisClusterException:
196 if any arguments are invalid or unknown. Eg:
197
198 - `db` != 0 or None
199 - `path` argument for unix socket connection
200 - none of the `host`/`port` & `startup_nodes` were provided
201
202 """
203
204 @classmethod
205 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
206 """
207 Return a Redis client object configured from the given URL.
208
209 For example::
210
211 redis://[[username]:[password]]@localhost:6379/0
212 rediss://[[username]:[password]]@localhost:6379/0
213
214 Three URL schemes are supported:
215
216 - `redis://` creates a TCP socket connection. See more at:
217 <https://www.iana.org/assignments/uri-schemes/prov/redis>
218 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
219 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
220
221 The username, password, hostname, path and all querystring values are passed
222 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
223 with their corresponding characters.
224
225 All querystring options are cast to their appropriate Python types. Boolean
226 arguments can be specified with string values "True"/"False" or "Yes"/"No".
227 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
228 parsed, the querystring arguments and keyword arguments are passed to
229 :class:`~redis.asyncio.connection.Connection` when created.
230 In the case of conflicting arguments, querystring arguments are used.
231 """
232 kwargs.update(parse_url(url))
233 if kwargs.pop("connection_class", None) is SSLConnection:
234 kwargs["ssl"] = True
235 return cls(**kwargs)
236
    # NOTE(review): __init__ assigns several attributes not listed here
    # (e.g. startup_nodes, load_balancing_strategy, _event_dispatcher,
    # _usage_counter, _usage_lock, _policy_resolver) — this only works if a
    # base class provides an instance __dict__; confirm intent.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
253
    @deprecated_args(
        args_to_warn=["read_from_replicas"],
        reason="Please configure the 'load_balancing_strategy' instead",
        version="5.3.0",
    )
    @deprecated_args(
        args_to_warn=[
            "cluster_error_retry_attempts",
        ],
        reason="Please configure the 'retry' object instead",
        version="6.0.0",
    )
    def __init__(
        self,
        host: Optional[str] = None,
        port: Union[str, int] = 6379,
        # Cluster related kwargs
        startup_nodes: Optional[List["ClusterNode"]] = None,
        require_full_coverage: bool = True,
        read_from_replicas: bool = False,
        load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
        dynamic_startup_nodes: bool = True,
        reinitialize_steps: int = 5,
        cluster_error_retry_attempts: int = 3,
        max_connections: int = 2**31,
        retry: Optional["Retry"] = None,
        retry_on_error: Optional[List[Type[Exception]]] = None,
        # Client related kwargs
        db: Union[str, int] = 0,
        path: Optional[str] = None,
        credential_provider: Optional[CredentialProvider] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        # Encoding related kwargs
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        # Connection related kwargs
        health_check_interval: float = 0,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_timeout: Optional[float] = None,
        # SSL related kwargs
        ssl: bool = False,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_certfile: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_keyfile: Optional[str] = None,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        protocol: Optional[int] = 2,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        # NOTE(review): this default instance is created once at definition
        # time and shared by every client that does not pass its own resolver —
        # confirm AsyncStaticPolicyResolver is stateless/shareable.
        policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
    ) -> None:
        """
        Build a cluster client without connecting; connections are created
        lazily by :meth:`initialize`. See the class docstring for parameter
        semantics.

        :raises RedisClusterException: for cluster-incompatible arguments
            (``db`` != 0, unix socket ``path``, or no startup node provided).
        """
        # Cluster mode has a single logical database; reject any other db.
        if db:
            raise RedisClusterException(
                "Argument 'db' must be 0 or None in cluster mode"
            )

        if path:
            raise RedisClusterException(
                "Unix domain socket is not supported in cluster mode"
            )

        # At least one contact point is required to discover the topology.
        if (not host or not port) and not startup_nodes:
            raise RedisClusterException(
                "RedisCluster requires at least one node to discover the cluster.\n"
                "Please provide one of the following or use RedisCluster.from_url:\n"
                ' - host and port: RedisCluster(host="localhost", port=6379)\n'
                " - startup_nodes: RedisCluster(startup_nodes=["
                'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
            )

        # Keyword arguments forwarded to every per-node Connection.
        kwargs: Dict[str, Any] = {
            "max_connections": max_connections,
            "connection_class": Connection,
            # Client related kwargs
            "credential_provider": credential_provider,
            "username": username,
            "password": password,
            "client_name": client_name,
            "lib_name": lib_name,
            "lib_version": lib_version,
            # Encoding related kwargs
            "encoding": encoding,
            "encoding_errors": encoding_errors,
            "decode_responses": decode_responses,
            # Connection related kwargs
            "health_check_interval": health_check_interval,
            "socket_connect_timeout": socket_connect_timeout,
            "socket_keepalive": socket_keepalive,
            "socket_keepalive_options": socket_keepalive_options,
            "socket_timeout": socket_timeout,
            "protocol": protocol,
        }

        if ssl:
            # SSL related kwargs
            kwargs.update(
                {
                    "connection_class": SSLConnection,
                    "ssl_ca_certs": ssl_ca_certs,
                    "ssl_ca_data": ssl_ca_data,
                    "ssl_cert_reqs": ssl_cert_reqs,
                    "ssl_include_verify_flags": ssl_include_verify_flags,
                    "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                    "ssl_certfile": ssl_certfile,
                    "ssl_check_hostname": ssl_check_hostname,
                    "ssl_keyfile": ssl_keyfile,
                    "ssl_min_version": ssl_min_version,
                    "ssl_ciphers": ssl_ciphers,
                }
            )

        if read_from_replicas or load_balancing_strategy:
            # Call our on_connect function to configure READONLY mode
            kwargs["redis_connect_func"] = self.on_connect

        # An explicit retry object wins; otherwise build a default strategy
        # honouring the deprecated cluster_error_retry_attempts count.
        if retry:
            self.retry = retry
        else:
            self.retry = Retry(
                backoff=ExponentialWithJitterBackoff(base=1, cap=10),
                retries=cluster_error_retry_attempts,
            )
        if retry_on_error:
            self.retry.update_supported_errors(retry_on_error)

        # Select RESP2/RESP3 response callbacks based on the protocol version.
        kwargs["response_callbacks"] = _RedisCallbacks.copy()
        if kwargs.get("protocol") in ["3", 3]:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
        else:
            kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
        self.connection_kwargs = kwargs

        # Re-create the passed startup nodes so they share this client's
        # connection kwargs.
        if startup_nodes:
            passed_nodes = []
            for node in startup_nodes:
                passed_nodes.append(
                    ClusterNode(node.host, node.port, **self.connection_kwargs)
                )
            startup_nodes = passed_nodes
        else:
            startup_nodes = []
        if host and port:
            startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

        self.startup_nodes = startup_nodes
        self.nodes_manager = NodesManager(
            startup_nodes,
            require_full_coverage,
            kwargs,
            dynamic_startup_nodes=dynamic_startup_nodes,
            address_remap=address_remap,
            event_dispatcher=self._event_dispatcher,
        )
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.read_from_replicas = read_from_replicas
        self.load_balancing_strategy = load_balancing_strategy
        self.reinitialize_steps = reinitialize_steps
        self.reinitialize_counter = 0

        # For backward compatibility, mapping from existing policies to new one
        self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
            self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
            self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
            self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
            self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
            self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
            SLOT_ID: RequestPolicy.DEFAULT_KEYED,
        }

        # Maps each request/response policy to the callable that produces the
        # target node list (request) or post-processes the result (response).
        self._policies_callback_mapping: dict[
            Union[RequestPolicy, ResponsePolicy], Callable
        ] = {
            RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
                self.get_random_primary_or_all_nodes(command_name)
            ],
            RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
            RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
            RequestPolicy.ALL_SHARDS: self.get_primaries,
            RequestPolicy.ALL_NODES: self.get_nodes,
            RequestPolicy.ALL_REPLICAS: self.get_replicas,
            RequestPolicy.SPECIAL: self.get_special_nodes,
            ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
            ResponsePolicy.DEFAULT_KEYED: lambda res: res,
        }

        self._policy_resolver = policy_resolver
        self.commands_parser = AsyncCommandsParser()
        # Nodes used by the last FT.AGGREGATE, consumed by get_special_nodes().
        self._aggregate_nodes = None
        self.node_flags = self.__class__.NODE_FLAGS.copy()
        self.command_flags = self.__class__.COMMAND_FLAGS.copy()
        self.response_callbacks = kwargs["response_callbacks"]
        self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
        self.result_callbacks["CLUSTER SLOTS"] = (
            lambda cmd, res, **kwargs: parse_cluster_slots(
                list(res.values())[0], **kwargs
            )
        )

        # Lazy-init flag and lock; initialize() flips _initialize to False.
        self._initialize = True
        self._lock: Optional[asyncio.Lock] = None

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()
477
478 async def initialize(self) -> "RedisCluster":
479 """Get all nodes from startup nodes & creates connections if not initialized."""
480 if self._initialize:
481 if not self._lock:
482 self._lock = asyncio.Lock()
483 async with self._lock:
484 if self._initialize:
485 try:
486 await self.nodes_manager.initialize()
487 await self.commands_parser.initialize(
488 self.nodes_manager.default_node
489 )
490 self._initialize = False
491 except BaseException:
492 await self.nodes_manager.aclose()
493 await self.nodes_manager.aclose("startup_nodes")
494 raise
495 return self
496
497 async def aclose(self) -> None:
498 """Close all connections & client if initialized."""
499 if not self._initialize:
500 if not self._lock:
501 self._lock = asyncio.Lock()
502 async with self._lock:
503 if not self._initialize:
504 self._initialize = True
505 await self.nodes_manager.aclose()
506 await self.nodes_manager.aclose("startup_nodes")
507
    @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Deprecated alias for :meth:`aclose`, kept for backwards compatibility."""
        await self.aclose()
512
    async def __aenter__(self) -> "RedisCluster":
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise
527
528 async def _increment_usage(self) -> int:
529 """
530 Helper coroutine to increment the usage counter while holding the lock.
531 Returns the new value of the usage counter.
532 """
533 async with self._usage_lock:
534 self._usage_counter += 1
535 return self._usage_counter
536
537 async def _decrement_usage(self) -> int:
538 """
539 Helper coroutine to decrement the usage counter while holding the lock.
540 Returns the new value of the usage counter.
541 """
542 async with self._usage_lock:
543 self._usage_counter -= 1
544 return self._usage_counter
545
546 async def __aexit__(self, exc_type, exc_value, traceback):
547 """
548 Async context manager exit. Decrements a usage counter. If this is the
549 last exit (counter becomes zero), the client closes its connection pool.
550 """
551 current_usage = await asyncio.shield(self._decrement_usage())
552 if current_usage == 0:
553 # This was the last active context, so disconnect the pool.
554 await asyncio.shield(self.aclose())
555
    def __await__(self) -> Generator[Any, None, "RedisCluster"]:
        # Allows ``await RedisCluster(...)`` as shorthand for initialize().
        return self.initialize().__await__()
558
559 _DEL_MESSAGE = "Unclosed RedisCluster client"
560
    def __del__(
        self,
        # Bound as defaults at definition time so they remain reachable during
        # interpreter shutdown, when module globals may already be cleared.
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn if the client is garbage-collected while still initialized,
        # i.e. aclose() was never awaited and connections may leak.
        if hasattr(self, "_initialize") and not self._initialize:
            _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                # Also report through the running loop's exception handler.
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop — the warning above is all we can do.
                pass
573
574 async def on_connect(self, connection: Connection) -> None:
575 await connection.on_connect()
576
577 # Sending READONLY command to server to configure connection as
578 # readonly. Since each cluster node may change its server type due
579 # to a failover, we should establish a READONLY connection
580 # regardless of the server type. If this is a primary connection,
581 # READONLY would not affect executing write commands.
582 await connection.send_command("READONLY")
583 if str_if_bytes(await connection.read_response()) != "OK":
584 raise ConnectionError("READONLY command failed")
585
586 def get_nodes(self) -> List["ClusterNode"]:
587 """Get all nodes of the cluster."""
588 return list(self.nodes_manager.nodes_cache.values())
589
    def get_primaries(self) -> List["ClusterNode"]:
        """Get the primary (master) nodes of the cluster."""
        return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
593
    def get_replicas(self) -> List["ClusterNode"]:
        """Get the replica nodes of the cluster."""
        return self.nodes_manager.get_nodes_by_server_type(REPLICA)
597
    def get_random_node(self) -> "ClusterNode":
        """Get a random node of the cluster (primary or replica)."""
        return random.choice(list(self.nodes_manager.nodes_cache.values()))
601
    def get_default_node(self) -> "ClusterNode":
        """Get the default node of the client (set by the nodes manager)."""
        return self.nodes_manager.default_node
605
606 def set_default_node(self, node: "ClusterNode") -> None:
607 """
608 Set the default node of the client.
609
610 :raises DataError: if None is passed or node does not exist in cluster.
611 """
612 if not node or not self.get_node(node_name=node.name):
613 raise DataError("The requested node does not exist in the cluster.")
614
615 self.nodes_manager.default_node = node
616
    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """Get node by (host, port) or node_name; None if not found."""
        return self.nodes_manager.get_node(host, port, node_name)
625
626 def get_node_from_key(
627 self, key: str, replica: bool = False
628 ) -> Optional["ClusterNode"]:
629 """
630 Get the cluster node corresponding to the provided key.
631
632 :param key:
633 :param replica:
634 | Indicates if a replica should be returned
635 |
636 None will returned if no replica holds this key
637
638 :raises SlotNotCoveredError: if the key is not covered by any slot.
639 """
640 slot = self.keyslot(key)
641 slot_cache = self.nodes_manager.slots_cache.get(slot)
642 if not slot_cache:
643 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
644
645 if replica:
646 if len(self.nodes_manager.slots_cache[slot]) < 2:
647 return None
648 node_idx = 1
649 else:
650 node_idx = 0
651
652 return slot_cache[node_idx]
653
    def get_random_primary_or_all_nodes(self, command_name):
        """
        Return a random node (any role) for read commands when reading from
        replicas is enabled; otherwise return a random primary node.
        """
        if self.read_from_replicas and command_name in READ_COMMANDS:
            return self.get_random_node()

        return self.get_random_primary_node()
662
    def get_random_primary_node(self) -> "ClusterNode":
        """
        Returns a random primary node
        """
        return random.choice(self.get_primaries())
668
669 async def get_nodes_from_slot(self, command: str, *args):
670 """
671 Returns a list of nodes that hold the specified keys' slots.
672 """
673 # get the node that holds the key's slot
674 return [
675 self.nodes_manager.get_node_from_slot(
676 await self._determine_slot(command, *args),
677 self.read_from_replicas and command in READ_COMMANDS,
678 self.load_balancing_strategy if command in READ_COMMANDS else None,
679 )
680 ]
681
682 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
683 """
684 Returns a list of nodes for commands with a special policy.
685 """
686 if not self._aggregate_nodes:
687 raise RedisClusterException(
688 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
689 )
690
691 return self._aggregate_nodes
692
    def keyslot(self, key: EncodableT) -> int:
        """
        Find the keyslot for a given key, after encoding it with the
        client's configured encoder.

        See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
        """
        return key_slot(self.encoder.encode(key))
700
    def get_encoder(self) -> Encoder:
        """Get the encoder object of the client."""
        return self.encoder
704
    def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
        """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
        return self.connection_kwargs
708
    def set_retry(self, retry: Retry) -> None:
        """Replace the client's retry strategy object."""
        self.retry = retry
711
    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback for the given command name."""
        self.response_callbacks[command] = callback
715
716 async def _determine_nodes(
717 self,
718 command: str,
719 *args: Any,
720 request_policy: RequestPolicy,
721 node_flag: Optional[str] = None,
722 ) -> List["ClusterNode"]:
723 # Determine which nodes should be executed the command on.
724 # Returns a list of target nodes.
725 if not node_flag:
726 # get the nodes group for this command if it was predefined
727 node_flag = self.command_flags.get(command)
728
729 if node_flag in self._command_flags_mapping:
730 request_policy = self._command_flags_mapping[node_flag]
731
732 policy_callback = self._policies_callback_mapping[request_policy]
733
734 if request_policy == RequestPolicy.DEFAULT_KEYED:
735 nodes = await policy_callback(command, *args)
736 elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
737 nodes = policy_callback(command)
738 else:
739 nodes = policy_callback()
740
741 if command.lower() == "ft.aggregate":
742 self._aggregate_nodes = nodes
743
744 return nodes
745
    async def _determine_slot(self, command: str, *args: Any) -> int:
        """Return the hash slot targeted by the command.

        :raises RedisClusterException: if the command carries no key, or its
            keys map to more than one slot.
        """
        if self.command_flags.get(command) == SLOT_ID:
            # The command contains the slot ID
            return int(args[0])

        # Get the keys in the command

        # EVAL and EVALSHA are common enough that it's wasteful to go to the
        # redis server to parse the keys. Besides, there is a bug in redis<7.0
        # where `self._get_command_keys()` fails anyway. So, we special case
        # EVAL/EVALSHA.
        # - issue: https://github.com/redis/redis/issues/9493
        # - fix: https://github.com/redis/redis/pull/9733
        if command.upper() in ("EVAL", "EVALSHA"):
            # command syntax: EVAL "script body" num_keys ...
            if len(args) < 2:
                raise RedisClusterException(
                    f"Invalid args in command: {command, *args}"
                )
            keys = args[2 : 2 + int(args[1])]
            # if there are 0 keys, that means the script can be run on any node
            # so we can just return a random slot
            if not keys:
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
        else:
            # Ask the server-synced command parser to extract the keys.
            keys = await self.commands_parser.get_keys(command, *args)
            if not keys:
                # FCALL can call a function with 0 keys, that means the function
                # can be run on any node so we can just return a random slot
                if command.upper() in ("FCALL", "FCALL_RO"):
                    return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
                raise RedisClusterException(
                    "No way to dispatch this command to Redis Cluster. "
                    "Missing key.\nYou can execute the command by specifying "
                    f"target nodes.\nCommand: {args}"
                )

        # single key command
        if len(keys) == 1:
            return self.keyslot(keys[0])

        # multi-key command; we need to make sure all keys are mapped to
        # the same slot
        slots = {self.keyslot(key) for key in keys}
        if len(slots) != 1:
            raise RedisClusterException(
                f"{command} - all keys must map to the same key slot"
            )

        return slots.pop()
796
    def _is_node_flag(self, target_nodes: Any) -> bool:
        # True when target_nodes is one of the predefined NODE_FLAGS strings
        # (PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE).
        return isinstance(target_nodes, str) and target_nodes in self.node_flags
799
800 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
801 if isinstance(target_nodes, list):
802 nodes = target_nodes
803 elif isinstance(target_nodes, ClusterNode):
804 # Supports passing a single ClusterNode as a variable
805 nodes = [target_nodes]
806 elif isinstance(target_nodes, dict):
807 # Supports dictionaries of the format {node_name: node}.
808 # It enables to execute commands with multi nodes as follows:
809 # rc.cluster_save_config(rc.get_primaries())
810 nodes = list(target_nodes.values())
811 else:
812 raise TypeError(
813 "target_nodes type can be one of the following: "
814 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
815 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
816 f"The passed type is {type(target_nodes)}"
817 )
818 return nodes
819
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            # Explicit nodes were given: execute exactly there, no retries.
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        # Ask the resolver for the command's request/response policies.
        command_policies = await self._policy_resolver.resolve(args[0].lower())

        if not command_policies and not target_nodes_specified:
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if not slot:
                    command_policies = CommandPolicies()
                else:
                    # The command has keys — route by slot.
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                # Translate a legacy per-command node flag into a policy.
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        for _ in range(execute_attempts):
            if self._initialize:
                await self.initialize()
            if (
                len(target_nodes) == 1
                and target_nodes[0] == self.get_default_node()
            ):
                # Replace the default cluster node
                self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently, then aggregate
                    # the per-node results keyed by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots caches should have been reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    continue
                else:
                    # raise the exception
                    raise e
936
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Send a single command to ``target_node`` and return its parsed reply,
        transparently following cluster redirections.

        The loop runs for at most ``RedisClusterRequestTTL`` attempts, handling
        MOVED/ASK redirects and transient cluster errors; unrecoverable errors
        are re-raised to the caller (``execute_command``), which owns the
        higher-level retry policy.

        :param target_node: node initially chosen for this command
        :param args: raw command arguments; ``args[0]`` is the command name
        :param kwargs: passed through to the node's ``execute_command``
        :raises ClusterError: when the redirection/retry budget is exhausted
        """
        asking = moved = False
        # Node name to contact on the next attempt after an ASK redirect.
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # ASK redirect: the redirected node must receive ASKING
                    # before it will serve the key for this one request.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                return await target_node.execute_command(*args, **kwargs)
            except BusyLoadingError:
                # The node is loading its dataset; let the caller decide.
                raise
            except MaxConnectionsError:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                raise
            except (ConnectionError, TimeoutError):
                # Connection retries are being handled in the node's
                # Retry object.
                # Remove the failed node from the startup nodes before we try
                # to reinitialize the cluster
                self.nodes_manager.startup_nodes.pop(target_node.name, None)
                # Hard force of reinitialize of the node/slots setup
                # and try again with the new setup
                await self.aclose()
                raise
            except (ClusterDownError, SlotNotCoveredError):
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    # Defer the cache patch: it is applied lazily by
                    # get_node_from_slot() on the next attempt.
                    self.nodes_manager._moved_exception = e
                moved = True
            except AskError as e:
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
            except TryAgainError:
                # Back off only once half of the TTL budget has been spent.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)

        raise ClusterError("TTL exhausted.")
1024
1025 def pipeline(
1026 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1027 ) -> "ClusterPipeline":
1028 """
1029 Create & return a new :class:`~.ClusterPipeline` object.
1030
1031 Cluster implementation of pipeline does not support transaction or shard_hint.
1032
1033 :raises RedisClusterException: if transaction or shard_hint are truthy values
1034 """
1035 if shard_hint:
1036 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1037
1038 return ClusterPipeline(self, transaction)
1039
1040 def lock(
1041 self,
1042 name: KeyT,
1043 timeout: Optional[float] = None,
1044 sleep: float = 0.1,
1045 blocking: bool = True,
1046 blocking_timeout: Optional[float] = None,
1047 lock_class: Optional[Type[Lock]] = None,
1048 thread_local: bool = True,
1049 raise_on_release_error: bool = True,
1050 ) -> Lock:
1051 """
1052 Return a new Lock object using key ``name`` that mimics
1053 the behavior of threading.Lock.
1054
1055 If specified, ``timeout`` indicates a maximum life for the lock.
1056 By default, it will remain locked until release() is called.
1057
1058 ``sleep`` indicates the amount of time to sleep per loop iteration
1059 when the lock is in blocking mode and another client is currently
1060 holding the lock.
1061
1062 ``blocking`` indicates whether calling ``acquire`` should block until
1063 the lock has been acquired or to fail immediately, causing ``acquire``
1064 to return False and the lock not being acquired. Defaults to True.
1065 Note this value can be overridden by passing a ``blocking``
1066 argument to ``acquire``.
1067
1068 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1069 spend trying to acquire the lock. A value of ``None`` indicates
1070 continue trying forever. ``blocking_timeout`` can be specified as a
1071 float or integer, both representing the number of seconds to wait.
1072
1073 ``lock_class`` forces the specified lock implementation. Note that as
1074 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1075 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1076 you have created your own custom lock class.
1077
1078 ``thread_local`` indicates whether the lock token is placed in
1079 thread-local storage. By default, the token is placed in thread local
1080 storage so that a thread only sees its token, not a token set by
1081 another thread. Consider the following timeline:
1082
1083 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1084 thread-1 sets the token to "abc"
1085 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1086 Lock instance.
1087 time: 5, thread-1 has not yet completed. redis expires the lock
1088 key.
1089 time: 5, thread-2 acquired `my-lock` now that it's available.
1090 thread-2 sets the token to "xyz"
1091 time: 6, thread-1 finishes its work and calls release(). if the
1092 token is *not* stored in thread local storage, then
1093 thread-1 would see the token value as "xyz" and would be
1094 able to successfully release the thread-2's lock.
1095
1096 ``raise_on_release_error`` indicates whether to raise an exception when
1097 the lock is no longer owned when exiting the context manager. By default,
1098 this is True, meaning an exception will be raised. If False, the warning
1099 will be logged and the exception will be suppressed.
1100
1101 In some use cases it's necessary to disable thread local storage. For
1102 example, if you have code where one thread acquires a lock and passes
1103 that lock instance to a worker thread to release later. If thread
1104 local storage isn't disabled in this case, the worker thread won't see
1105 the token set by the thread that acquired the lock. Our assumption
1106 is that these cases aren't common and as such default to using
1107 thread local storage."""
1108 if lock_class is None:
1109 lock_class = Lock
1110 return lock_class(
1111 self,
1112 name,
1113 timeout=timeout,
1114 sleep=sleep,
1115 blocking=blocking,
1116 blocking_timeout=blocking_timeout,
1117 thread_local=thread_local,
1118 raise_on_release_error=raise_on_release_error,
1119 )
1120
1121 async def transaction(
1122 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1123 ):
1124 """
1125 Convenience method for executing the callable `func` as a transaction
1126 while watching all keys specified in `watches`. The 'func' callable
1127 should expect a single argument which is a Pipeline object.
1128 """
1129 shard_hint = kwargs.pop("shard_hint", None)
1130 value_from_callable = kwargs.pop("value_from_callable", False)
1131 watch_delay = kwargs.pop("watch_delay", None)
1132 async with self.pipeline(True, shard_hint) as pipe:
1133 while True:
1134 try:
1135 if watches:
1136 await pipe.watch(*watches)
1137 func_value = await func(pipe)
1138 exec_value = await pipe.execute()
1139 return func_value if value_from_callable else exec_value
1140 except WatchError:
1141 if watch_delay is not None and watch_delay > 0:
1142 time.sleep(watch_delay)
1143 continue
1144
1145
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port). Idle connections are kept in a free-list
    (``_free``) and reused by :meth:`acquire_connection` / :meth:`release`.
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # Resolve "localhost" eagerly so node names stay consistent with the
        # addresses reported by CLUSTER SLOTS.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        # PRIMARY / REPLICA (or None when unknown).
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Callbacks are popped out of connection_kwargs so they are not
        # forwarded to the Connection constructor.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # Every connection ever created for this node (in-use and idle).
        self._connections: List[Connection] = []
        # Idle connections available for reuse.
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are equal when they refer to the same host:port name.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # ``_warn``/``_grl`` are bound as default arguments so they remain
        # reachable during interpreter shutdown, when module globals may
        # already have been cleared.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop to report to.
                    pass
                # One warning per node is enough.
                break

    async def disconnect(self) -> None:
        """
        Disconnect all of this node's connections concurrently.

        All disconnects are attempted; if any raised, the first exception
        encountered is re-raised afterwards.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Return an idle connection, creating a new one if the pool has room.

        :raises MaxConnectionsError: when no connection is free and
            ``max_connections`` has been reached.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        """
        self._free.append(connection)

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and post-process it for ``command``.

        Honours the NEVER_DECODE and EMPTY_RESPONSE kwargs markers and applies
        any registered response callback before returning.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # EMPTY_RESPONSE provides a fallback value for commands that are
            # allowed to fail with a ResponseError (e.g. on missing keys).
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send a single command on a pooled connection and return its parsed
        reply. The connection is always returned to the free-list, even when
        reading the response fails.
        """
        # Acquire connection
        connection = self.acquire_connection()

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in one batch and store each result (or the
        exception it raised) on ``cmd.result``.

        :return: True if at least one command failed, False otherwise.
        """
        # Acquire connection
        connection = self.acquire_connection()

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                # Keep reading the remaining replies; the caller inspects
                # cmd.result to decide how to handle failures.
                cmd.result = e
                ret = True

        # Release connection
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every idle connection with a freshly issued token.

        Connections are drained from the free-list, AUTH'ed one by one, and
        then returned, so none is re-acquired mid-auth.
        """
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1363
1364
class NodesManager:
    """
    Owns the cluster topology state for the async cluster client: the node
    cache, the slot-to-nodes mapping, the startup nodes, the default node and
    the read load balancer. Topology is (re)built by :meth:`initialize` from
    a CLUSTER SLOTS reply and patched in place on MOVED redirects.
    """

    __slots__ = (
        "_dynamic_startup_nodes",
        "_moved_exception",
        "_event_dispatcher",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "read_load_balancer",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )

    def __init__(
        self,
        startup_nodes: List["ClusterNode"],
        require_full_coverage: bool,
        connection_kwargs: Dict[str, Any],
        dynamic_startup_nodes: bool = True,
        address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ) -> None:
        self.startup_nodes = {node.name: node for node in startup_nodes}
        self.require_full_coverage = require_full_coverage
        self.connection_kwargs = connection_kwargs
        self.address_remap = address_remap

        self.default_node: "ClusterNode" = None
        # name -> node for every known node.
        self.nodes_cache: Dict[str, "ClusterNode"] = {}
        # slot -> [primary, replica, ...] for every covered slot.
        self.slots_cache: Dict[int, List["ClusterNode"]] = {}
        self.read_load_balancer = LoadBalancer()

        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
        # Pending MOVED error; consumed lazily by get_node_from_slot().
        self._moved_exception: MovedError = None
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    def get_node(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        node_name: Optional[str] = None,
    ) -> Optional["ClusterNode"]:
        """
        Look up a cached node either by (host, port) or by node name.

        :raises DataError: when neither a name nor a host/port pair is given.
        """
        if host and port:
            # the user passed host and port
            if host == "localhost":
                host = socket.gethostbyname(host)
            return self.nodes_cache.get(get_node_name(host=host, port=port))
        elif node_name:
            return self.nodes_cache.get(node_name)
        else:
            raise DataError(
                "get_node requires one of the following: 1. node name 2. host and port"
            )

    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge ``new`` into ``old`` in place, scheduling disconnect tasks for
        every node that is replaced (and, when ``remove_old`` is True, for
        nodes that no longer appear in ``new``).
        """
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    task = asyncio.create_task(old.pop(name).disconnect())  # noqa
        for name, node in new.items():
            if name in old:
                if old[name] is node:
                    continue
                # Same name but a different node object: retire the old one.
                task = asyncio.create_task(old[name].disconnect())  # noqa
            old[name] = node

    def update_moved_exception(self, exception):
        # Record a MOVED error to be applied on the next slot lookup.
        self._moved_exception = exception

    def _update_moved_slots(self) -> None:
        """
        Apply the pending MOVED error to the nodes/slots caches: promote the
        redirected node to primary owner of the slot and, on a failover,
        demote the previous primary to replica.
        """
        e = self._moved_exception
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        slot_nodes = self.slots_cache[e.slot_id]
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op

        # Reset moved_exception
        self._moved_exception = None

    def get_node_from_slot(
        self,
        slot: int,
        read_from_replicas: bool = False,
        load_balancing_strategy=None,
    ) -> "ClusterNode":
        """
        Return the node that should serve ``slot``: the primary by default,
        or a load-balanced choice among the slot's nodes when replica reads
        are enabled.

        :raises SlotNotCoveredError: when the slot has no cached owner.
        """
        # Apply any MOVED redirect recorded since the last lookup.
        if self._moved_exception:
            self._update_moved_slots()

        if read_from_replicas is True and load_balancing_strategy is None:
            load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN

        try:
            if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
                # get the server index using the strategy defined in load_balancing_strategy
                primary_name = self.slots_cache[slot][0].name
                node_idx = self.read_load_balancer.get_server_index(
                    primary_name, len(self.slots_cache[slot]), load_balancing_strategy
                )
                return self.slots_cache[slot][node_idx]
            return self.slots_cache[slot][0]
        except (IndexError, TypeError):
            raise SlotNotCoveredError(
                f'Slot "{slot}" not covered by the cluster. '
                f'"require_full_coverage={self.require_full_coverage}"'
            )

    def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
        """Return all cached nodes whose server_type matches (PRIMARY/REPLICA)."""
        return [
            node
            for node in self.nodes_cache.values()
            if node.server_type == server_type
        ]

    async def initialize(self) -> None:
        """
        Rebuild the nodes and slots caches from a CLUSTER SLOTS reply.

        Startup nodes are tried in turn until one answers; the last error is
        re-raised if none is reachable.

        :raises RedisClusterException: if no startup node is reachable, if a
            node does not run in cluster mode, if the startup nodes disagree
            on the slot layout, or if coverage is incomplete while
            ``require_full_coverage`` is set.
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Convert to tuple to prevent RuntimeError if self.startup_nodes
        # is modified during iteration
        for startup_node in tuple(self.startup_nodes.values()):
            try:
                # Make sure cluster mode is enabled on this node
                try:
                    self._event_dispatcher.dispatch(
                        AfterAsyncClusterInstantiationEvent(
                            self.nodes_cache,
                            self.connection_kwargs.get("credential_provider", None),
                        )
                    )
                    cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
                except ResponseError:
                    raise RedisClusterException(
                        "Cluster mode is not enabled on this node"
                    )
                startup_nodes_reachable = True
            except Exception as e:
                # Try the next startup node.
                # The exception is saved and raised only if we have no more nodes.
                exception = e
                continue

            # CLUSTER SLOTS command results in the following output:
            # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
            # where each node contains the following list: [IP, port, node_id]
            # Therefore, cluster_slots[0][2][0] will be the IP address of the
            # primary node of the first slot section.
            # If there's only one server in the cluster, its ``host`` is ''
            # Fix it to the host in startup_nodes
            if (
                len(cluster_slots) == 1
                and not cluster_slots[0][2][0]
                and len(self.startup_nodes) == 1
            ):
                cluster_slots[0][2][0] = startup_node.host

            for slot in cluster_slots:
                for i in range(2, len(slot)):
                    slot[i] = [str_if_bytes(val) for val in slot[i]]
                primary_node = slot[2]
                host = primary_node[0]
                if host == "":
                    host = startup_node.host
                port = int(primary_node[1])
                host, port = self.remap_host_port(host, port)

                nodes_for_slot = []

                # Reuse an already-built node object when possible.
                target_node = tmp_nodes_cache.get(get_node_name(host, port))
                if not target_node:
                    target_node = ClusterNode(
                        host, port, PRIMARY, **self.connection_kwargs
                    )
                # add this node to the nodes cache
                tmp_nodes_cache[target_node.name] = target_node
                nodes_for_slot.append(target_node)

                replica_nodes = slot[3:]
                for replica_node in replica_nodes:
                    host = replica_node[0]
                    port = replica_node[1]
                    host, port = self.remap_host_port(host, port)

                    target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_replica_node:
                        target_replica_node = ClusterNode(
                            host, port, REPLICA, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_replica_node.name] = target_replica_node
                    nodes_for_slot.append(target_replica_node)

                for i in range(int(slot[0]), int(slot[1]) + 1):
                    if i not in tmp_slots:
                        tmp_slots[i] = nodes_for_slot
                    else:
                        # Validate that 2 nodes want to use the same slot cache
                        # setup
                        tmp_slot = tmp_slots[i][0]
                        if tmp_slot.name != target_node.name:
                            disagreements.append(
                                f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                            )

                            if len(disagreements) > 5:
                                raise RedisClusterException(
                                    f"startup_nodes could not agree on a valid "
                                    f"slots cache: {', '.join(disagreements)}"
                                )

            # Validate if all slots are covered or if we should try next startup node
            fully_covered = True
            for i in range(REDIS_CLUSTER_HASH_SLOTS):
                if i not in tmp_slots:
                    fully_covered = False
                    break
            if fully_covered:
                break

        if not startup_nodes_reachable:
            raise RedisClusterException(
                f"Redis Cluster cannot be connected. Please provide at least "
                f"one reachable node: {str(exception)}"
            ) from exception

        # Check if the slots are not fully covered
        if not fully_covered and self.require_full_coverage:
            # Despite the requirement that the slots be covered, there
            # isn't a full coverage
            raise RedisClusterException(
                f"All slots are not covered after query all startup_nodes. "
                f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                f"covered..."
            )

        # Set the tmp variables to the real variables
        self.slots_cache = tmp_slots
        self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

        if self._dynamic_startup_nodes:
            # Populate the startup nodes with all discovered nodes
            self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

        # Set the default node
        self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
        # If initialize was called after a MovedError, clear it
        self._moved_exception = None

    async def aclose(self, attr: str = "nodes_cache") -> None:
        """
        Disconnect every node referenced by ``attr`` ("nodes_cache" or
        "startup_nodes") and clear the default node.
        """
        self.default_node = None
        await asyncio.gather(
            *(
                asyncio.create_task(node.disconnect())
                for node in getattr(self, attr).values()
            )
        )

    def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
        """
        Remap the host and port returned from the cluster to a different
        internal value. Useful if the client is not connecting directly
        to the cluster.
        """
        if self.address_remap:
            return self.address_remap((host, port))
        return host, port
1676
1677
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Pipeline for an async :class:`~.RedisCluster` client.

    Commands are buffered client-side and dispatched when :meth:`execute`
    is awaited.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    """

    __slots__ = ("cluster_client", "_transaction", "_execution_strategy")

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # A truthy ``transaction`` flag selects MULTI/EXEC semantics;
        # otherwise commands run as a plain (non-atomic) pipeline.
        if self._transaction:
            self._execution_strategy: ExecutionStrategy = TransactionStrategy(self)
        else:
            self._execution_strategy = PipelineStrategy(self)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the active execution strategy and return this pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        # Entering the async context ensures the pipeline is initialized.
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        # Leaving the context always discards any queued state.
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        # Allows ``await rc.pipeline()`` as shorthand for initialize().
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        """A pipeline is truthy regardless of how many commands are queued."""
        return True

    def __len__(self) -> int:
        # Number of commands currently buffered by the strategy.
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run every buffered command and return the list of results.

        Commands are retried as specified by the retries configured in
        :attr:`retry` before an exception is raised. The pipeline is always
        reset afterwards, whether execution succeeded or not.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        # Queue one copy of ``command`` per hash slot its keys land in.
        slot_partitions = self.cluster_client._partition_keys_by_slot(keys)
        for slot_keys in slot_partitions.values():
            self.execute_command(command, *slot_keys)
        return self

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """ """
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``"""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Unlink (asynchronously delete) the keys ``names``."""
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue per-slot MSET commands covering every pair in ``mapping``."""
        return self._execution_strategy.mset_nonatomic(mapping)
1829
1830
# Install stubs on ClusterPipeline for every command that is not allowed in a
# cluster pipeline; mset_nonatomic is skipped because it has a real
# implementation above.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
1837
1838
class PipelineCommand:
    """
    Record of one queued pipeline command: its raw arguments, its index in
    the queue, and (after execution) its result or the exception it raised.
    """

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index of this command within the pipeline's queue.
        self.position = position
        # Raw command arguments and per-command kwargs, kept verbatim.
        self.args = args
        self.kwargs = kwargs
        # Populated during execution: the parsed reply, or the exception.
        self.result: Union[Any, Exception] = None
        # Request/response policies resolved for this command, if any.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return f"[{self.position}] {self.args} ({self.kwargs})"
1849
1850
class ExecutionStrategy(ABC):
    """
    Abstract interface for pipeline execution modes.

    Concrete strategies (``PipelineStrategy``, ``TransactionStrategy``)
    implement these operations; :class:`ClusterPipeline` delegates every
    public call to its active strategy.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """Prepare the strategy for use. See ClusterPipeline.initialize()."""
        pass

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Queue a raw command. See ClusterPipeline.execute_command()."""
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run all queued commands, retrying as configured in :attr:`retry`
        before raising. See ClusterPipeline.execute().
        """
        pass

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue MSET commands per slot for the given mapping.
        See ClusterPipeline.mset_nonatomic().
        """
        pass

    @abstractmethod
    async def reset(self):
        """Drop all queued state. See ClusterPipeline.reset()."""
        pass

    @abstractmethod
    def multi(self):
        """Start a transactional block. See ClusterPipeline.multi()."""
        pass

    @abstractmethod
    async def watch(self, *names):
        """Watch the given keys. See ClusterPipeline.watch()."""
        pass

    @abstractmethod
    async def unwatch(self):
        """Unwatch all previously watched keys. See ClusterPipeline.unwatch()."""
        pass

    @abstractmethod
    async def discard(self):
        """See ClusterPipeline.discard()."""
        pass

    @abstractmethod
    async def unlink(self, *names):
        """Unlink the keys ``names``. See ClusterPipeline.unlink()."""
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Number of queued commands."""
        pass
1949
1950
class AbstractStrategy(ExecutionStrategy):
    """Shared plumbing for concrete execution strategies.

    Holds the owning pipeline and the local command queue, and provides the
    common queueing / error-annotation helpers; routing and flushing are
    left to subclasses.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Ensure the underlying cluster client is ready and clear the queue."""
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Buffer a raw command; its queue index becomes its position."""
        position = len(self._command_queue)
        self._command_queue.append(PipelineCommand(position, *args, **kwargs))
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """Prefix ``exception``'s message with the command's pipeline position."""
        cmd_text = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd_text)}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Issue multiple MSETs according to the slot/pairs mapping."""

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Flush the buffered commands and return their results."""

    @abstractmethod
    async def reset(self):
        """Reset this strategy's state."""

    @abstractmethod
    def multi(self):
        """Open a transactional context."""

    @abstractmethod
    async def watch(self, *names):
        """Watch the given keys."""

    @abstractmethod
    async def unwatch(self):
        """Forget all previously watched keys."""

    @abstractmethod
    async def discard(self):
        """Discard buffered commands / transaction state."""

    @abstractmethod
    async def unlink(self, *names):
        """Unlink the key(s) given by ``names``."""

    def __len__(self) -> int:
        """Number of commands currently buffered."""
        return len(self._command_queue)
2019
2020
class PipelineStrategy(AbstractStrategy):
    """Default (non-transactional) execution strategy for ClusterPipeline.

    Buffered commands are grouped by their target node and each node's
    batch is flushed concurrently. There is no atomicity across commands;
    MOVED/ASK/TryAgain failures may be retried per command when
    ``allow_redirections`` is enabled.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot covering every pair in ``mapping``.

        Keys are bucketed by the cluster hash slot of their encoded form,
        so the overall operation is not atomic across slots.
        """
        encoder = self._pipe.cluster_client.encoder

        # Bucket key/value pairs by the slot their key hashes to.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Send all buffered commands and return results in queue order.

        Cluster-level failures (RedisCluster.ERRORS_ALLOW_RETRY) are retried
        up to the client's configured retry count; the client is closed and a
        short pause taken before each new attempt so topology is rebuilt.
        The command queue is always cleared afterwards, even on failure.
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # All other errors should be raised.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """Route each pending command to its node, flush the per-node batches
        concurrently, and collect per-command results onto ``stack``.
        """
        # Commands still needing a run: never-executed ones and prior failures.
        # NOTE(review): a falsy successful result (e.g. 0 or an empty list) is
        # also re-selected by this filter — confirm that is intended.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # Map node name -> (node, list of commands destined for that node).
        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Caller pinned explicit target nodes; skip node determination.
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    # No resolved policy: fall back to the client's static
                    # command flags, then to a slot-based default.
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if not slot:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            # A pipelined command may only be sent to a single node.
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Flush each node's batch concurrently; results (or exceptions) are
        # written onto the individual PipelineCommand objects.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                # Surface the first failed command, annotated with its
                # 1-based position in the pipeline.
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

        default_cluster_node = client.get_default_node()

        # Check whether the default node was used. In some cases,
        # 'client.get_default_node()' may return None. The check below
        # prevents a potential AttributeError.
        if default_cluster_node is not None:
            default_node = nodes.get(default_cluster_node.name)
            if default_node is not None:
                # This pipeline execution used the default node, check if we need
                # to replace it.
                # Note: when the error is raised we'll reset the default node in the
                # caller function.
                for cmd in default_node[1]:
                    # Check if it has a command that failed with a relevant
                    # exception
                    if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                        client.replace_default_node()
                        break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        """Unsupported outside of a transactional pipeline."""
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        """Unsupported outside of a transactional pipeline."""
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        """Unsupported outside of a transactional pipeline."""
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        """Unsupported outside of a transactional pipeline."""
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue an UNLINK for exactly one key (multi-key unlink unsupported)."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2224
2225
class TransactionStrategy(AbstractStrategy):
    """MULTI/EXEC execution strategy for a cluster pipeline.

    Every key touched by the transaction must hash to a single slot; the
    strategy pins one connection to the node owning that slot so WATCH
    state survives until EXEC/DISCARD.
    """

    # Commands that can be sent without resolving a hash slot first.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands executed immediately on the pinned connection, never queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands whose completion clears the WATCH state.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    # Redirections that, while watching, indicate slot rebalancing.
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    # Errors treated as connection-level failures (connection is dropped).
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        # True once MULTI has been explicitly issued.
        self._explicit_transaction = False
        # True while WATCH state is active on the pinned connection.
        self._watching = False
        # Slots touched so far; a cluster transaction may only use one.
        self._pipeline_slots: Set[int] = set()
        # Node/connection pinned for the lifetime of the transaction.
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        # True while EXEC is in flight (turns slot redirects into WatchError).
        self._executing = False
        # Private retry copy that additionally retries slot redirect errors.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not been
        altered as long as the watch commands for those keys were sent over the same
        connection. So once we start watching a key, we fetch a connection to the
        node that owns that slot and reuse it.
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        # All slots in the set map to the same node (cross-slot is rejected
        # before EXEC); pick any one to resolve the owning node.
        node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
            list(self._pipeline_slots)[0], False
        )
        self._transaction_node = node

        if not self._transaction_connection:
            connection: Connection = self._transaction_node.acquire_connection()
            self._transaction_connection = connection

        return self._transaction_node, self._transaction_connection

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """Drive the async `_execute_command` from the pipeline's sync API.

        A throwaway thread with its own event loop (`asyncio.run`) is used so
        this synchronous method never touches the caller's running loop.
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """Either execute ``args`` immediately (WATCH/UNWATCH, or any command
        while watching outside MULTI) or buffer it for the transaction.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # NOTE(review): the coroutine from _immediate_execute_command is
            # returned un-awaited; execute_command()'s helper loop hands it
            # back to the caller and watch()/unwatch() await it on the
            # caller's own event loop — confirm this handoff is intended.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            # Buffer the command; AbstractStrategy returns the pipeline.
            return super().execute_command(*args, **kwargs)

    def _validate_watch(self):
        """Reject WATCH after MULTI and mark the strategy as watching."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        """Send ``args`` right away on the pinned connection, with retries."""
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        """Resolve the pinned node/connection, then send and parse ``args``."""
        redis_node, connection = self._get_client_and_connection_for_transaction()
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        # DISCARD/EXEC/UNWATCH implicitly clear any WATCH state.
        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        """Retry hook: drop the pinned connection and refresh cluster state
        after redirect/connection errors; redirects while EXEC is in flight
        become WatchError.
        """
        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        if (
            type(error) in self.SLOT_REDIRECT_ERRORS
            or type(error) in self.CONNECTION_ERRORS
        ):
            if self._transaction_connection:
                self._transaction_connection = None

            self._pipe.cluster_client.reinitialize_counter += 1
            if (
                self._pipe.cluster_client.reinitialize_steps
                and self._pipe.cluster_client.reinitialize_counter
                % self._pipe.cluster_client.reinitialize_steps
                == 0
            ):
                await self._pipe.cluster_client.nodes_manager.initialize()
                # NOTE(review): this sets an attribute on the strategy, not
                # the client's reinitialize_counter; the client's counter
                # keeps growing (the modulo check above still fires on each
                # multiple of reinitialize_steps) — confirm intent.
                self.reinitialize_counter = 0
            else:
                if isinstance(error, AskError):
                    self._pipe.cluster_client.nodes_manager.update_moved_exception(
                        error
                    )

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Not available inside a transaction (it spans multiple slots)."""
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Run the buffered commands inside MULTI/EXEC and return the replies.

        ``allow_redirections`` is accepted for interface compatibility but is
        not used by the transactional strategy.
        """
        stack = self._command_queue
        # Nothing buffered and nothing watched: nothing to do.
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Run the transaction, retrying via the strategy's private retry."""
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Send MULTI + buffered commands + EXEC over the pinned connection
        and map the EXEC reply back onto the queued commands.
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()

        # Wrap the queued commands with MULTI ... EXEC (positions of the
        # wrappers are placeholders; only their args are sent).
        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        # A nil EXEC reply means a watched key changed and the transaction
        # was aborted by the server.
        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): this unpack assumes every remaining entry is an
        # (index, value) tuple from EMPTY_RESPONSE; errors appended above as
        # bare exceptions are expected to have aborted EXEC — confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)
        return data

    async def reset(self):
        """Clear the queue and all transaction state, releasing (or dropping)
        the pinned connection.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection:
                    await self._transaction_connection.disconnect()

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        """Start an explicit transaction block (must precede queued commands)."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        """Watch ``names``; must be called before MULTI."""
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        # execute_command is the sync bridge; for immediate commands it
        # returns a coroutine which is awaited here on the caller's loop.
        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Forget all watched keys; no-op (True) when nothing is watched."""
        if self._watching:
            return await self.execute_command("UNWATCH")

        return True

    async def discard(self):
        """Discard the queued commands and reset all transaction state."""
        await self.reset()

    async def unlink(self, *names):
        """Queue (or, while watching, immediately run) UNLINK for ``names``."""
        # execute_command is the sync bridge; nothing to await on the
        # queueing path, which returns the pipeline itself.
        return self.execute_command("UNLINK", *names)