1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 Any,
13 Callable,
14 Coroutine,
15 Deque,
16 Dict,
17 Generator,
18 List,
19 Mapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26)
27
28from redis._parsers import AsyncCommandsParser, Encoder
29from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34)
35from redis.asyncio.client import ResponseCallbackT
36from redis.asyncio.connection import Connection, SSLConnection, parse_url
37from redis.asyncio.lock import Lock
38from redis.asyncio.retry import Retry
39from redis.auth.token import TokenInterface
40from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
41from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
42from redis.cluster import (
43 PIPELINE_BLOCKED_COMMANDS,
44 PRIMARY,
45 REPLICA,
46 SLOT_ID,
47 AbstractRedisCluster,
48 LoadBalancer,
49 LoadBalancingStrategy,
50 block_pipeline_command,
51 get_node_name,
52 parse_cluster_slots,
53)
54from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
55from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
56from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
57from redis.credentials import CredentialProvider
58from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
59from redis.exceptions import (
60 AskError,
61 BusyLoadingError,
62 ClusterDownError,
63 ClusterError,
64 ConnectionError,
65 CrossSlotTransactionError,
66 DataError,
67 ExecAbortError,
68 InvalidPipelineStack,
69 MaxConnectionsError,
70 MovedError,
71 RedisClusterException,
72 RedisError,
73 ResponseError,
74 SlotNotCoveredError,
75 TimeoutError,
76 TryAgainError,
77 WatchError,
78)
79from redis.typing import AnyKeyT, EncodableT, KeyT
80from redis.utils import (
81 SSL_AVAILABLE,
82 deprecated_args,
83 deprecated_function,
84 get_lib_version,
85 safe_str,
86 str_if_bytes,
87 truncate_text,
88)
89
90if SSL_AVAILABLE:
91 from ssl import TLSVersion, VerifyFlags, VerifyMode
92else:
93 TLSVersion = None
94 VerifyMode = None
95 VerifyFlags = None
96
97TargetNodesT = TypeVar(
98 "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
99)
100
101
102class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
103 """
104 Create a new RedisCluster client.
105
106 Pass one of parameters:
107
108 - `host` & `port`
109 - `startup_nodes`
110
111 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
112 | Use ``await`` :meth:`close` to disconnect connections & close client.
113
114 Many commands support the target_nodes kwarg. It can be one of the
115 :attr:`NODE_FLAGS`:
116
117 - :attr:`PRIMARIES`
118 - :attr:`REPLICAS`
119 - :attr:`ALL_NODES`
120 - :attr:`RANDOM`
121 - :attr:`DEFAULT_NODE`
122
123 Note: This client is not thread/process/fork safe.
124
125 :param host:
126 | Can be used to point to a startup node
127 :param port:
128 | Port used if **host** is provided
129 :param startup_nodes:
130 | :class:`~.ClusterNode` to used as a startup node
131 :param require_full_coverage:
132 | When set to ``False``: the client will not require a full coverage of
133 the slots. However, if not all slots are covered, and at least one node
134 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
135 a :class:`~.ClusterDownError` for some key-based commands.
136 | When set to ``True``: all slots must be covered to construct the cluster
137 client. If not all slots are covered, :class:`~.RedisClusterException` will be
138 thrown.
139 | See:
140 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
141 :param read_from_replicas:
142 | @deprecated - please use load_balancing_strategy instead
143 | Enable read from replicas in READONLY mode.
144 When set to true, read commands will be assigned between the primary and
145 its replications in a Round-Robin manner.
146 The data read from replicas is eventually consistent with the data in primary nodes.
147 :param load_balancing_strategy:
148 | Enable read from replicas in READONLY mode and defines the load balancing
149 strategy that will be used for cluster node selection.
150 The data read from replicas is eventually consistent with the data in primary nodes.
151 :param dynamic_startup_nodes:
152 | Set the RedisCluster's startup nodes to all the discovered nodes.
153 If true (default value), the cluster's discovered nodes will be used to
154 determine the cluster nodes-slots mapping in the next topology refresh.
155 It will remove the initial passed startup nodes if their endpoints aren't
156 listed in the CLUSTER SLOTS output.
157 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
158 specific IP addresses, it is best to set it to false.
159 :param reinitialize_steps:
160 | Specifies the number of MOVED errors that need to occur before reinitializing
161 the whole cluster topology. If a MOVED error occurs and the cluster does not
162 need to be reinitialized on this current error handling, only the MOVED slot
163 will be patched with the redirected node.
164 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
165 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
166 0.
167 :param cluster_error_retry_attempts:
168 | @deprecated - Please configure the 'retry' object instead
169 In case 'retry' object is set - this argument is ignored!
170
171 Number of times to retry before raising an error when :class:`~.TimeoutError`,
172 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
173 or :class:`~.ClusterDownError` are encountered
174 :param retry:
175 | A retry object that defines the retry strategy and the number of
176 retries for the cluster client.
177 In current implementation for the cluster client (starting form redis-py version 6.0.0)
178 the retry object is not yet fully utilized, instead it is used just to determine
179 the number of retries for the cluster client.
180 In the future releases the retry object will be used to handle the cluster client retries!
181 :param max_connections:
182 | Maximum number of connections per node. If there are no free connections & the
183 maximum number of connections are already created, a
184 :class:`~.MaxConnectionsError` is raised.
185 :param address_remap:
186 | An optional callable which, when provided with an internal network
187 address of a node, e.g. a `(host, port)` tuple, will return the address
188 where the node is reachable. This can be used to map the addresses at
189 which the nodes _think_ they are, to addresses at which a client may
190 reach them, such as when they sit behind a proxy.
191
192 | Rest of the arguments will be passed to the
193 :class:`~redis.asyncio.connection.Connection` instances when created
194
195 :raises RedisClusterException:
196 if any arguments are invalid or unknown. Eg:
197
198 - `db` != 0 or None
199 - `path` argument for unix socket connection
200 - none of the `host`/`port` & `startup_nodes` were provided
201
202 """
203
204 @classmethod
205 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
206 """
207 Return a Redis client object configured from the given URL.
208
209 For example::
210
211 redis://[[username]:[password]]@localhost:6379/0
212 rediss://[[username]:[password]]@localhost:6379/0
213
214 Three URL schemes are supported:
215
216 - `redis://` creates a TCP socket connection. See more at:
217 <https://www.iana.org/assignments/uri-schemes/prov/redis>
218 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
219 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
220
221 The username, password, hostname, path and all querystring values are passed
222 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
223 with their corresponding characters.
224
225 All querystring options are cast to their appropriate Python types. Boolean
226 arguments can be specified with string values "True"/"False" or "Yes"/"No".
227 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
228 parsed, the querystring arguments and keyword arguments are passed to
229 :class:`~redis.asyncio.connection.Connection` when created.
230 In the case of conflicting arguments, querystring arguments are used.
231 """
232 kwargs.update(parse_url(url))
233 if kwargs.pop("connection_class", None) is SSLConnection:
234 kwargs["ssl"] = True
235 return cls(**kwargs)
236
237 __slots__ = (
238 "_initialize",
239 "_lock",
240 "retry",
241 "command_flags",
242 "commands_parser",
243 "connection_kwargs",
244 "encoder",
245 "node_flags",
246 "nodes_manager",
247 "read_from_replicas",
248 "reinitialize_counter",
249 "reinitialize_steps",
250 "response_callbacks",
251 "result_callbacks",
252 )
253
254 @deprecated_args(
255 args_to_warn=["read_from_replicas"],
256 reason="Please configure the 'load_balancing_strategy' instead",
257 version="5.3.0",
258 )
259 @deprecated_args(
260 args_to_warn=[
261 "cluster_error_retry_attempts",
262 ],
263 reason="Please configure the 'retry' object instead",
264 version="6.0.0",
265 )
266 def __init__(
267 self,
268 host: Optional[str] = None,
269 port: Union[str, int] = 6379,
270 # Cluster related kwargs
271 startup_nodes: Optional[List["ClusterNode"]] = None,
272 require_full_coverage: bool = True,
273 read_from_replicas: bool = False,
274 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
275 dynamic_startup_nodes: bool = True,
276 reinitialize_steps: int = 5,
277 cluster_error_retry_attempts: int = 3,
278 max_connections: int = 2**31,
279 retry: Optional["Retry"] = None,
280 retry_on_error: Optional[List[Type[Exception]]] = None,
281 # Client related kwargs
282 db: Union[str, int] = 0,
283 path: Optional[str] = None,
284 credential_provider: Optional[CredentialProvider] = None,
285 username: Optional[str] = None,
286 password: Optional[str] = None,
287 client_name: Optional[str] = None,
288 lib_name: Optional[str] = "redis-py",
289 lib_version: Optional[str] = get_lib_version(),
290 # Encoding related kwargs
291 encoding: str = "utf-8",
292 encoding_errors: str = "strict",
293 decode_responses: bool = False,
294 # Connection related kwargs
295 health_check_interval: float = 0,
296 socket_connect_timeout: Optional[float] = None,
297 socket_keepalive: bool = False,
298 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
299 socket_timeout: Optional[float] = None,
300 # SSL related kwargs
301 ssl: bool = False,
302 ssl_ca_certs: Optional[str] = None,
303 ssl_ca_data: Optional[str] = None,
304 ssl_cert_reqs: Union[str, VerifyMode] = "required",
305 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
306 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
307 ssl_certfile: Optional[str] = None,
308 ssl_check_hostname: bool = True,
309 ssl_keyfile: Optional[str] = None,
310 ssl_min_version: Optional[TLSVersion] = None,
311 ssl_ciphers: Optional[str] = None,
312 protocol: Optional[int] = 2,
313 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
314 event_dispatcher: Optional[EventDispatcher] = None,
315 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
316 ) -> None:
317 if db:
318 raise RedisClusterException(
319 "Argument 'db' must be 0 or None in cluster mode"
320 )
321
322 if path:
323 raise RedisClusterException(
324 "Unix domain socket is not supported in cluster mode"
325 )
326
327 if (not host or not port) and not startup_nodes:
328 raise RedisClusterException(
329 "RedisCluster requires at least one node to discover the cluster.\n"
330 "Please provide one of the following or use RedisCluster.from_url:\n"
331 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
332 " - startup_nodes: RedisCluster(startup_nodes=["
333 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
334 )
335
336 kwargs: Dict[str, Any] = {
337 "max_connections": max_connections,
338 "connection_class": Connection,
339 # Client related kwargs
340 "credential_provider": credential_provider,
341 "username": username,
342 "password": password,
343 "client_name": client_name,
344 "lib_name": lib_name,
345 "lib_version": lib_version,
346 # Encoding related kwargs
347 "encoding": encoding,
348 "encoding_errors": encoding_errors,
349 "decode_responses": decode_responses,
350 # Connection related kwargs
351 "health_check_interval": health_check_interval,
352 "socket_connect_timeout": socket_connect_timeout,
353 "socket_keepalive": socket_keepalive,
354 "socket_keepalive_options": socket_keepalive_options,
355 "socket_timeout": socket_timeout,
356 "protocol": protocol,
357 }
358
359 if ssl:
360 # SSL related kwargs
361 kwargs.update(
362 {
363 "connection_class": SSLConnection,
364 "ssl_ca_certs": ssl_ca_certs,
365 "ssl_ca_data": ssl_ca_data,
366 "ssl_cert_reqs": ssl_cert_reqs,
367 "ssl_include_verify_flags": ssl_include_verify_flags,
368 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
369 "ssl_certfile": ssl_certfile,
370 "ssl_check_hostname": ssl_check_hostname,
371 "ssl_keyfile": ssl_keyfile,
372 "ssl_min_version": ssl_min_version,
373 "ssl_ciphers": ssl_ciphers,
374 }
375 )
376
377 if read_from_replicas or load_balancing_strategy:
378 # Call our on_connect function to configure READONLY mode
379 kwargs["redis_connect_func"] = self.on_connect
380
381 if retry:
382 self.retry = retry
383 else:
384 self.retry = Retry(
385 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
386 retries=cluster_error_retry_attempts,
387 )
388 if retry_on_error:
389 self.retry.update_supported_errors(retry_on_error)
390
391 kwargs["response_callbacks"] = _RedisCallbacks.copy()
392 if kwargs.get("protocol") in ["3", 3]:
393 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
394 else:
395 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
396 self.connection_kwargs = kwargs
397
398 if startup_nodes:
399 passed_nodes = []
400 for node in startup_nodes:
401 passed_nodes.append(
402 ClusterNode(node.host, node.port, **self.connection_kwargs)
403 )
404 startup_nodes = passed_nodes
405 else:
406 startup_nodes = []
407 if host and port:
408 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
409
410 if event_dispatcher is None:
411 self._event_dispatcher = EventDispatcher()
412 else:
413 self._event_dispatcher = event_dispatcher
414
415 self.startup_nodes = startup_nodes
416 self.nodes_manager = NodesManager(
417 startup_nodes,
418 require_full_coverage,
419 kwargs,
420 dynamic_startup_nodes=dynamic_startup_nodes,
421 address_remap=address_remap,
422 event_dispatcher=self._event_dispatcher,
423 )
424 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
425 self.read_from_replicas = read_from_replicas
426 self.load_balancing_strategy = load_balancing_strategy
427 self.reinitialize_steps = reinitialize_steps
428 self.reinitialize_counter = 0
429
430 # For backward compatibility, mapping from existing policies to new one
431 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
432 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
433 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
434 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
435 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
436 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
437 SLOT_ID: RequestPolicy.DEFAULT_KEYED,
438 }
439
440 self._policies_callback_mapping: dict[
441 Union[RequestPolicy, ResponsePolicy], Callable
442 ] = {
443 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
444 self.get_random_primary_or_all_nodes(command_name)
445 ],
446 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
447 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
448 RequestPolicy.ALL_SHARDS: self.get_primaries,
449 RequestPolicy.ALL_NODES: self.get_nodes,
450 RequestPolicy.ALL_REPLICAS: self.get_replicas,
451 RequestPolicy.SPECIAL: self.get_special_nodes,
452 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
453 ResponsePolicy.DEFAULT_KEYED: lambda res: res,
454 }
455
456 self._policy_resolver = policy_resolver
457 self.commands_parser = AsyncCommandsParser()
458 self._aggregate_nodes = None
459 self.node_flags = self.__class__.NODE_FLAGS.copy()
460 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
461 self.response_callbacks = kwargs["response_callbacks"]
462 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
463 self.result_callbacks["CLUSTER SLOTS"] = (
464 lambda cmd, res, **kwargs: parse_cluster_slots(
465 list(res.values())[0], **kwargs
466 )
467 )
468
469 self._initialize = True
470 self._lock: Optional[asyncio.Lock] = None
471
472 # When used as an async context manager, we need to increment and decrement
473 # a usage counter so that we can close the connection pool when no one is
474 # using the client.
475 self._usage_counter = 0
476 self._usage_lock = asyncio.Lock()
477
478 async def initialize(self) -> "RedisCluster":
479 """Get all nodes from startup nodes & creates connections if not initialized."""
480 if self._initialize:
481 if not self._lock:
482 self._lock = asyncio.Lock()
483 async with self._lock:
484 if self._initialize:
485 try:
486 await self.nodes_manager.initialize()
487 await self.commands_parser.initialize(
488 self.nodes_manager.default_node
489 )
490 self._initialize = False
491 except BaseException:
492 await self.nodes_manager.aclose()
493 await self.nodes_manager.aclose("startup_nodes")
494 raise
495 return self
496
497 async def aclose(self) -> None:
498 """Close all connections & client if initialized."""
499 if not self._initialize:
500 if not self._lock:
501 self._lock = asyncio.Lock()
502 async with self._lock:
503 if not self._initialize:
504 self._initialize = True
505 await self.nodes_manager.aclose()
506 await self.nodes_manager.aclose("startup_nodes")
507
508 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
509 async def close(self) -> None:
510 """alias for aclose() for backwards compatibility"""
511 await self.aclose()
512
513 async def __aenter__(self) -> "RedisCluster":
514 """
515 Async context manager entry. Increments a usage counter so that the
516 connection pool is only closed (via aclose()) when no context is using
517 the client.
518 """
519 await self._increment_usage()
520 try:
521 # Initialize the client (i.e. establish connection, etc.)
522 return await self.initialize()
523 except Exception:
524 # If initialization fails, decrement the counter to keep it in sync
525 await self._decrement_usage()
526 raise
527
528 async def _increment_usage(self) -> int:
529 """
530 Helper coroutine to increment the usage counter while holding the lock.
531 Returns the new value of the usage counter.
532 """
533 async with self._usage_lock:
534 self._usage_counter += 1
535 return self._usage_counter
536
537 async def _decrement_usage(self) -> int:
538 """
539 Helper coroutine to decrement the usage counter while holding the lock.
540 Returns the new value of the usage counter.
541 """
542 async with self._usage_lock:
543 self._usage_counter -= 1
544 return self._usage_counter
545
546 async def __aexit__(self, exc_type, exc_value, traceback):
547 """
548 Async context manager exit. Decrements a usage counter. If this is the
549 last exit (counter becomes zero), the client closes its connection pool.
550 """
551 current_usage = await asyncio.shield(self._decrement_usage())
552 if current_usage == 0:
553 # This was the last active context, so disconnect the pool.
554 await asyncio.shield(self.aclose())
555
556 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
557 return self.initialize().__await__()
558
559 _DEL_MESSAGE = "Unclosed RedisCluster client"
560
561 def __del__(
562 self,
563 _warn: Any = warnings.warn,
564 _grl: Any = asyncio.get_running_loop,
565 ) -> None:
566 if hasattr(self, "_initialize") and not self._initialize:
567 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
568 try:
569 context = {"client": self, "message": self._DEL_MESSAGE}
570 _grl().call_exception_handler(context)
571 except RuntimeError:
572 pass
573
574 async def on_connect(self, connection: Connection) -> None:
575 await connection.on_connect()
576
577 # Sending READONLY command to server to configure connection as
578 # readonly. Since each cluster node may change its server type due
579 # to a failover, we should establish a READONLY connection
580 # regardless of the server type. If this is a primary connection,
581 # READONLY would not affect executing write commands.
582 await connection.send_command("READONLY")
583 if str_if_bytes(await connection.read_response()) != "OK":
584 raise ConnectionError("READONLY command failed")
585
586 def get_nodes(self) -> List["ClusterNode"]:
587 """Get all nodes of the cluster."""
588 return list(self.nodes_manager.nodes_cache.values())
589
590 def get_primaries(self) -> List["ClusterNode"]:
591 """Get the primary nodes of the cluster."""
592 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
593
594 def get_replicas(self) -> List["ClusterNode"]:
595 """Get the replica nodes of the cluster."""
596 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
597
598 def get_random_node(self) -> "ClusterNode":
599 """Get a random node of the cluster."""
600 return random.choice(list(self.nodes_manager.nodes_cache.values()))
601
602 def get_default_node(self) -> "ClusterNode":
603 """Get the default node of the client."""
604 return self.nodes_manager.default_node
605
606 def set_default_node(self, node: "ClusterNode") -> None:
607 """
608 Set the default node of the client.
609
610 :raises DataError: if None is passed or node does not exist in cluster.
611 """
612 if not node or not self.get_node(node_name=node.name):
613 raise DataError("The requested node does not exist in the cluster.")
614
615 self.nodes_manager.default_node = node
616
617 def get_node(
618 self,
619 host: Optional[str] = None,
620 port: Optional[int] = None,
621 node_name: Optional[str] = None,
622 ) -> Optional["ClusterNode"]:
623 """Get node by (host, port) or node_name."""
624 return self.nodes_manager.get_node(host, port, node_name)
625
626 def get_node_from_key(
627 self, key: str, replica: bool = False
628 ) -> Optional["ClusterNode"]:
629 """
630 Get the cluster node corresponding to the provided key.
631
632 :param key:
633 :param replica:
634 | Indicates if a replica should be returned
635 |
636 None will returned if no replica holds this key
637
638 :raises SlotNotCoveredError: if the key is not covered by any slot.
639 """
640 slot = self.keyslot(key)
641 slot_cache = self.nodes_manager.slots_cache.get(slot)
642 if not slot_cache:
643 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
644
645 if replica:
646 if len(self.nodes_manager.slots_cache[slot]) < 2:
647 return None
648 node_idx = 1
649 else:
650 node_idx = 0
651
652 return slot_cache[node_idx]
653
654 def get_random_primary_or_all_nodes(self, command_name):
655 """
656 Returns random primary or all nodes depends on READONLY mode.
657 """
658 if self.read_from_replicas and command_name in READ_COMMANDS:
659 return self.get_random_node()
660
661 return self.get_random_primary_node()
662
663 def get_random_primary_node(self) -> "ClusterNode":
664 """
665 Returns a random primary node
666 """
667 return random.choice(self.get_primaries())
668
669 async def get_nodes_from_slot(self, command: str, *args):
670 """
671 Returns a list of nodes that hold the specified keys' slots.
672 """
673 # get the node that holds the key's slot
674 return [
675 self.nodes_manager.get_node_from_slot(
676 await self._determine_slot(command, *args),
677 self.read_from_replicas and command in READ_COMMANDS,
678 self.load_balancing_strategy if command in READ_COMMANDS else None,
679 )
680 ]
681
682 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
683 """
684 Returns a list of nodes for commands with a special policy.
685 """
686 if not self._aggregate_nodes:
687 raise RedisClusterException(
688 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
689 )
690
691 return self._aggregate_nodes
692
693 def keyslot(self, key: EncodableT) -> int:
694 """
695 Find the keyslot for a given key.
696
697 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
698 """
699 return key_slot(self.encoder.encode(key))
700
701 def get_encoder(self) -> Encoder:
702 """Get the encoder object of the client."""
703 return self.encoder
704
705 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
706 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
707 return self.connection_kwargs
708
709 def set_retry(self, retry: Retry) -> None:
710 self.retry = retry
711
712 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
713 """Set a custom response callback."""
714 self.response_callbacks[command] = callback
715
716 async def _determine_nodes(
717 self,
718 command: str,
719 *args: Any,
720 request_policy: RequestPolicy,
721 node_flag: Optional[str] = None,
722 ) -> List["ClusterNode"]:
723 # Determine which nodes should be executed the command on.
724 # Returns a list of target nodes.
725 if not node_flag:
726 # get the nodes group for this command if it was predefined
727 node_flag = self.command_flags.get(command)
728
729 if node_flag in self._command_flags_mapping:
730 request_policy = self._command_flags_mapping[node_flag]
731
732 policy_callback = self._policies_callback_mapping[request_policy]
733
734 if request_policy == RequestPolicy.DEFAULT_KEYED:
735 nodes = await policy_callback(command, *args)
736 elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
737 nodes = policy_callback(command)
738 else:
739 nodes = policy_callback()
740
741 if command.lower() == "ft.aggregate":
742 self._aggregate_nodes = nodes
743
744 return nodes
745
746 async def _determine_slot(self, command: str, *args: Any) -> int:
747 if self.command_flags.get(command) == SLOT_ID:
748 # The command contains the slot ID
749 return int(args[0])
750
751 # Get the keys in the command
752
753 # EVAL and EVALSHA are common enough that it's wasteful to go to the
754 # redis server to parse the keys. Besides, there is a bug in redis<7.0
755 # where `self._get_command_keys()` fails anyway. So, we special case
756 # EVAL/EVALSHA.
757 # - issue: https://github.com/redis/redis/issues/9493
758 # - fix: https://github.com/redis/redis/pull/9733
759 if command.upper() in ("EVAL", "EVALSHA"):
760 # command syntax: EVAL "script body" num_keys ...
761 if len(args) < 2:
762 raise RedisClusterException(
763 f"Invalid args in command: {command, *args}"
764 )
765 keys = args[2 : 2 + int(args[1])]
766 # if there are 0 keys, that means the script can be run on any node
767 # so we can just return a random slot
768 if not keys:
769 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
770 else:
771 keys = await self.commands_parser.get_keys(command, *args)
772 if not keys:
773 # FCALL can call a function with 0 keys, that means the function
774 # can be run on any node so we can just return a random slot
775 if command.upper() in ("FCALL", "FCALL_RO"):
776 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
777 raise RedisClusterException(
778 "No way to dispatch this command to Redis Cluster. "
779 "Missing key.\nYou can execute the command by specifying "
780 f"target nodes.\nCommand: {args}"
781 )
782
783 # single key command
784 if len(keys) == 1:
785 return self.keyslot(keys[0])
786
787 # multi-key command; we need to make sure all keys are mapped to
788 # the same slot
789 slots = {self.keyslot(key) for key in keys}
790 if len(slots) != 1:
791 raise RedisClusterException(
792 f"{command} - all keys must map to the same key slot"
793 )
794
795 return slots.pop()
796
797 def _is_node_flag(self, target_nodes: Any) -> bool:
798 return isinstance(target_nodes, str) and target_nodes in self.node_flags
799
800 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
801 if isinstance(target_nodes, list):
802 nodes = target_nodes
803 elif isinstance(target_nodes, ClusterNode):
804 # Supports passing a single ClusterNode as a variable
805 nodes = [target_nodes]
806 elif isinstance(target_nodes, dict):
807 # Supports dictionaries of the format {node_name: node}.
808 # It enables to execute commands with multi nodes as follows:
809 # rc.cluster_save_config(rc.get_primaries())
810 nodes = list(target_nodes.values())
811 else:
812 raise TypeError(
813 "target_nodes type can be one of the following: "
814 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
815 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
816 f"The passed type is {type(target_nodes)}"
817 )
818 return nodes
819
820 async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
821 """
822 Execute a raw command on the appropriate cluster node or target_nodes.
823
824 It will retry the command as specified by the retries property of
825 the :attr:`retry` & then raise an exception.
826
827 :param args:
828 | Raw command args
829 :param kwargs:
830
831 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
832 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
833 - Rest of the kwargs are passed to the Redis connection
834
835 :raises RedisClusterException: if target_nodes is not provided & the command
836 can't be mapped to a slot
837 """
838 command = args[0]
839 target_nodes = []
840 target_nodes_specified = False
841 retry_attempts = self.retry.get_retries()
842
843 passed_targets = kwargs.pop("target_nodes", None)
844 if passed_targets and not self._is_node_flag(passed_targets):
845 target_nodes = self._parse_target_nodes(passed_targets)
846 target_nodes_specified = True
847 retry_attempts = 0
848
849 command_policies = await self._policy_resolver.resolve(args[0].lower())
850
851 if not command_policies and not target_nodes_specified:
852 command_flag = self.command_flags.get(command)
853 if not command_flag:
854 # Fallback to default policy
855 if not self.get_default_node():
856 slot = None
857 else:
858 slot = await self._determine_slot(*args)
859 if not slot:
860 command_policies = CommandPolicies()
861 else:
862 command_policies = CommandPolicies(
863 request_policy=RequestPolicy.DEFAULT_KEYED,
864 response_policy=ResponsePolicy.DEFAULT_KEYED,
865 )
866 else:
867 if command_flag in self._command_flags_mapping:
868 command_policies = CommandPolicies(
869 request_policy=self._command_flags_mapping[command_flag]
870 )
871 else:
872 command_policies = CommandPolicies()
873 elif not command_policies and target_nodes_specified:
874 command_policies = CommandPolicies()
875
876 # Add one for the first execution
877 execute_attempts = 1 + retry_attempts
878 for _ in range(execute_attempts):
879 if self._initialize:
880 await self.initialize()
881 if (
882 len(target_nodes) == 1
883 and target_nodes[0] == self.get_default_node()
884 ):
885 # Replace the default cluster node
886 self.replace_default_node()
887 try:
888 if not target_nodes_specified:
889 # Determine the nodes to execute the command on
890 target_nodes = await self._determine_nodes(
891 *args,
892 request_policy=command_policies.request_policy,
893 node_flag=passed_targets,
894 )
895 if not target_nodes:
896 raise RedisClusterException(
897 f"No targets were found to execute {args} command on"
898 )
899
900 if len(target_nodes) == 1:
901 # Return the processed result
902 ret = await self._execute_command(target_nodes[0], *args, **kwargs)
903 if command in self.result_callbacks:
904 ret = self.result_callbacks[command](
905 command, {target_nodes[0].name: ret}, **kwargs
906 )
907 return self._policies_callback_mapping[
908 command_policies.response_policy
909 ](ret)
910 else:
911 keys = [node.name for node in target_nodes]
912 values = await asyncio.gather(
913 *(
914 asyncio.create_task(
915 self._execute_command(node, *args, **kwargs)
916 )
917 for node in target_nodes
918 )
919 )
920 if command in self.result_callbacks:
921 return self.result_callbacks[command](
922 command, dict(zip(keys, values)), **kwargs
923 )
924 return self._policies_callback_mapping[
925 command_policies.response_policy
926 ](dict(zip(keys, values)))
927 except Exception as e:
928 if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
929 # The nodes and slots cache were should be reinitialized.
930 # Try again with the new cluster setup.
931 retry_attempts -= 1
932 continue
933 else:
934 # raise the exception
935 raise e
936
937 async def _execute_command(
938 self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
939 ) -> Any:
940 asking = moved = False
941 redirect_addr = None
942 ttl = self.RedisClusterRequestTTL
943
944 while ttl > 0:
945 ttl -= 1
946 try:
947 if asking:
948 target_node = self.get_node(node_name=redirect_addr)
949 await target_node.execute_command("ASKING")
950 asking = False
951 elif moved:
952 # MOVED occurred and the slots cache was updated,
953 # refresh the target node
954 slot = await self._determine_slot(*args)
955 target_node = self.nodes_manager.get_node_from_slot(
956 slot,
957 self.read_from_replicas and args[0] in READ_COMMANDS,
958 self.load_balancing_strategy
959 if args[0] in READ_COMMANDS
960 else None,
961 )
962 moved = False
963
964 return await target_node.execute_command(*args, **kwargs)
965 except BusyLoadingError:
966 raise
967 except MaxConnectionsError:
968 # MaxConnectionsError indicates client-side resource exhaustion
969 # (too many connections in the pool), not a node failure.
970 # Don't treat this as a node failure - just re-raise the error
971 # without reinitializing the cluster.
972 raise
973 except (ConnectionError, TimeoutError):
974 # Connection retries are being handled in the node's
975 # Retry object.
976 # Remove the failed node from the startup nodes before we try
977 # to reinitialize the cluster
978 self.nodes_manager.startup_nodes.pop(target_node.name, None)
979 # Hard force of reinitialize of the node/slots setup
980 # and try again with the new setup
981 await self.aclose()
982 raise
983 except (ClusterDownError, SlotNotCoveredError):
984 # ClusterDownError can occur during a failover and to get
985 # self-healed, we will try to reinitialize the cluster layout
986 # and retry executing the command
987
988 # SlotNotCoveredError can occur when the cluster is not fully
989 # initialized or can be temporary issue.
990 # We will try to reinitialize the cluster topology
991 # and retry executing the command
992
993 await self.aclose()
994 await asyncio.sleep(0.25)
995 raise
996 except MovedError as e:
997 # First, we will try to patch the slots/nodes cache with the
998 # redirected node output and try again. If MovedError exceeds
999 # 'reinitialize_steps' number of times, we will force
1000 # reinitializing the tables, and then try again.
1001 # 'reinitialize_steps' counter will increase faster when
1002 # the same client object is shared between multiple threads. To
1003 # reduce the frequency you can set this variable in the
1004 # RedisCluster constructor.
1005 self.reinitialize_counter += 1
1006 if (
1007 self.reinitialize_steps
1008 and self.reinitialize_counter % self.reinitialize_steps == 0
1009 ):
1010 await self.aclose()
1011 # Reset the counter
1012 self.reinitialize_counter = 0
1013 else:
1014 self.nodes_manager._moved_exception = e
1015 moved = True
1016 except AskError as e:
1017 redirect_addr = get_node_name(host=e.host, port=e.port)
1018 asking = True
1019 except TryAgainError:
1020 if ttl < self.RedisClusterRequestTTL / 2:
1021 await asyncio.sleep(0.05)
1022
1023 raise ClusterError("TTL exhausted.")
1024
1025 def pipeline(
1026 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1027 ) -> "ClusterPipeline":
1028 """
1029 Create & return a new :class:`~.ClusterPipeline` object.
1030
1031 Cluster implementation of pipeline does not support transaction or shard_hint.
1032
1033 :raises RedisClusterException: if transaction or shard_hint are truthy values
1034 """
1035 if shard_hint:
1036 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1037
1038 return ClusterPipeline(self, transaction)
1039
1040 def lock(
1041 self,
1042 name: KeyT,
1043 timeout: Optional[float] = None,
1044 sleep: float = 0.1,
1045 blocking: bool = True,
1046 blocking_timeout: Optional[float] = None,
1047 lock_class: Optional[Type[Lock]] = None,
1048 thread_local: bool = True,
1049 raise_on_release_error: bool = True,
1050 ) -> Lock:
1051 """
1052 Return a new Lock object using key ``name`` that mimics
1053 the behavior of threading.Lock.
1054
1055 If specified, ``timeout`` indicates a maximum life for the lock.
1056 By default, it will remain locked until release() is called.
1057
1058 ``sleep`` indicates the amount of time to sleep per loop iteration
1059 when the lock is in blocking mode and another client is currently
1060 holding the lock.
1061
1062 ``blocking`` indicates whether calling ``acquire`` should block until
1063 the lock has been acquired or to fail immediately, causing ``acquire``
1064 to return False and the lock not being acquired. Defaults to True.
1065 Note this value can be overridden by passing a ``blocking``
1066 argument to ``acquire``.
1067
1068 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1069 spend trying to acquire the lock. A value of ``None`` indicates
1070 continue trying forever. ``blocking_timeout`` can be specified as a
1071 float or integer, both representing the number of seconds to wait.
1072
1073 ``lock_class`` forces the specified lock implementation. Note that as
1074 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1075 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1076 you have created your own custom lock class.
1077
1078 ``thread_local`` indicates whether the lock token is placed in
1079 thread-local storage. By default, the token is placed in thread local
1080 storage so that a thread only sees its token, not a token set by
1081 another thread. Consider the following timeline:
1082
1083 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1084 thread-1 sets the token to "abc"
1085 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1086 Lock instance.
1087 time: 5, thread-1 has not yet completed. redis expires the lock
1088 key.
1089 time: 5, thread-2 acquired `my-lock` now that it's available.
1090 thread-2 sets the token to "xyz"
1091 time: 6, thread-1 finishes its work and calls release(). if the
1092 token is *not* stored in thread local storage, then
1093 thread-1 would see the token value as "xyz" and would be
1094 able to successfully release the thread-2's lock.
1095
1096 ``raise_on_release_error`` indicates whether to raise an exception when
1097 the lock is no longer owned when exiting the context manager. By default,
1098 this is True, meaning an exception will be raised. If False, the warning
1099 will be logged and the exception will be suppressed.
1100
1101 In some use cases it's necessary to disable thread local storage. For
1102 example, if you have code where one thread acquires a lock and passes
1103 that lock instance to a worker thread to release later. If thread
1104 local storage isn't disabled in this case, the worker thread won't see
1105 the token set by the thread that acquired the lock. Our assumption
1106 is that these cases aren't common and as such default to using
1107 thread local storage."""
1108 if lock_class is None:
1109 lock_class = Lock
1110 return lock_class(
1111 self,
1112 name,
1113 timeout=timeout,
1114 sleep=sleep,
1115 blocking=blocking,
1116 blocking_timeout=blocking_timeout,
1117 thread_local=thread_local,
1118 raise_on_release_error=raise_on_release_error,
1119 )
1120
1121 async def transaction(
1122 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1123 ):
1124 """
1125 Convenience method for executing the callable `func` as a transaction
1126 while watching all keys specified in `watches`. The 'func' callable
1127 should expect a single argument which is a Pipeline object.
1128 """
1129 shard_hint = kwargs.pop("shard_hint", None)
1130 value_from_callable = kwargs.pop("value_from_callable", False)
1131 watch_delay = kwargs.pop("watch_delay", None)
1132 async with self.pipeline(True, shard_hint) as pipe:
1133 while True:
1134 try:
1135 if watches:
1136 await pipe.watch(*watches)
1137 func_value = await func(pipe)
1138 exec_value = await pipe.execute()
1139 return func_value if value_from_callable else exec_value
1140 except WatchError:
1141 if watch_delay is not None and watch_delay > 0:
1142 time.sleep(watch_delay)
1143 continue
1144
1145
1146class ClusterNode:
1147 """
1148 Create a new ClusterNode.
1149
1150 Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
1151 objects for the (host, port).
1152 """
1153
1154 __slots__ = (
1155 "_connections",
1156 "_free",
1157 "_lock",
1158 "_event_dispatcher",
1159 "connection_class",
1160 "connection_kwargs",
1161 "host",
1162 "max_connections",
1163 "name",
1164 "port",
1165 "response_callbacks",
1166 "server_type",
1167 )
1168
1169 def __init__(
1170 self,
1171 host: str,
1172 port: Union[str, int],
1173 server_type: Optional[str] = None,
1174 *,
1175 max_connections: int = 2**31,
1176 connection_class: Type[Connection] = Connection,
1177 **connection_kwargs: Any,
1178 ) -> None:
1179 if host == "localhost":
1180 host = socket.gethostbyname(host)
1181
1182 connection_kwargs["host"] = host
1183 connection_kwargs["port"] = port
1184 self.host = host
1185 self.port = port
1186 self.name = get_node_name(host, port)
1187 self.server_type = server_type
1188
1189 self.max_connections = max_connections
1190 self.connection_class = connection_class
1191 self.connection_kwargs = connection_kwargs
1192 self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
1193
1194 self._connections: List[Connection] = []
1195 self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
1196 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1197 if self._event_dispatcher is None:
1198 self._event_dispatcher = EventDispatcher()
1199
1200 def __repr__(self) -> str:
1201 return (
1202 f"[host={self.host}, port={self.port}, "
1203 f"name={self.name}, server_type={self.server_type}]"
1204 )
1205
1206 def __eq__(self, obj: Any) -> bool:
1207 return isinstance(obj, ClusterNode) and obj.name == self.name
1208
1209 _DEL_MESSAGE = "Unclosed ClusterNode object"
1210
1211 def __del__(
1212 self,
1213 _warn: Any = warnings.warn,
1214 _grl: Any = asyncio.get_running_loop,
1215 ) -> None:
1216 for connection in self._connections:
1217 if connection.is_connected:
1218 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
1219
1220 try:
1221 context = {"client": self, "message": self._DEL_MESSAGE}
1222 _grl().call_exception_handler(context)
1223 except RuntimeError:
1224 pass
1225 break
1226
1227 async def disconnect(self) -> None:
1228 ret = await asyncio.gather(
1229 *(
1230 asyncio.create_task(connection.disconnect())
1231 for connection in self._connections
1232 ),
1233 return_exceptions=True,
1234 )
1235 exc = next((res for res in ret if isinstance(res, Exception)), None)
1236 if exc:
1237 raise exc
1238
1239 def acquire_connection(self) -> Connection:
1240 try:
1241 return self._free.popleft()
1242 except IndexError:
1243 if len(self._connections) < self.max_connections:
1244 # We are configuring the connection pool not to retry
1245 # connections on lower level clients to avoid retrying
1246 # connections to nodes that are not reachable
1247 # and to avoid blocking the connection pool.
1248 # The only error that will have some handling in the lower
1249 # level clients is ConnectionError which will trigger disconnection
1250 # of the socket.
1251 # The retries will be handled on cluster client level
1252 # where we will have proper handling of the cluster topology
1253 retry = Retry(
1254 backoff=NoBackoff(),
1255 retries=0,
1256 supported_errors=(ConnectionError,),
1257 )
1258 connection_kwargs = self.connection_kwargs.copy()
1259 connection_kwargs["retry"] = retry
1260 connection = self.connection_class(**connection_kwargs)
1261 self._connections.append(connection)
1262 return connection
1263
1264 raise MaxConnectionsError()
1265
1266 def release(self, connection: Connection) -> None:
1267 """
1268 Release connection back to free queue.
1269 """
1270 self._free.append(connection)
1271
1272 async def parse_response(
1273 self, connection: Connection, command: str, **kwargs: Any
1274 ) -> Any:
1275 try:
1276 if NEVER_DECODE in kwargs:
1277 response = await connection.read_response(disable_decoding=True)
1278 kwargs.pop(NEVER_DECODE)
1279 else:
1280 response = await connection.read_response()
1281 except ResponseError:
1282 if EMPTY_RESPONSE in kwargs:
1283 return kwargs[EMPTY_RESPONSE]
1284 raise
1285
1286 if EMPTY_RESPONSE in kwargs:
1287 kwargs.pop(EMPTY_RESPONSE)
1288
1289 # Remove keys entry, it needs only for cache.
1290 kwargs.pop("keys", None)
1291
1292 # Return response
1293 if command in self.response_callbacks:
1294 return self.response_callbacks[command](response, **kwargs)
1295
1296 return response
1297
1298 async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
1299 # Acquire connection
1300 connection = self.acquire_connection()
1301
1302 # Execute command
1303 await connection.send_packed_command(connection.pack_command(*args), False)
1304
1305 # Read response
1306 try:
1307 return await self.parse_response(connection, args[0], **kwargs)
1308 finally:
1309 # Release connection
1310 self._free.append(connection)
1311
1312 async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1313 # Acquire connection
1314 connection = self.acquire_connection()
1315
1316 # Execute command
1317 await connection.send_packed_command(
1318 connection.pack_commands(cmd.args for cmd in commands), False
1319 )
1320
1321 # Read responses
1322 ret = False
1323 for cmd in commands:
1324 try:
1325 cmd.result = await self.parse_response(
1326 connection, cmd.args[0], **cmd.kwargs
1327 )
1328 except Exception as e:
1329 cmd.result = e
1330 ret = True
1331
1332 # Release connection
1333 self._free.append(connection)
1334
1335 return ret
1336
1337 async def re_auth_callback(self, token: TokenInterface):
1338 tmp_queue = collections.deque()
1339 while self._free:
1340 conn = self._free.popleft()
1341 await conn.retry.call_with_retry(
1342 lambda: conn.send_command(
1343 "AUTH", token.try_get("oid"), token.get_value()
1344 ),
1345 lambda error: self._mock(error),
1346 )
1347 await conn.retry.call_with_retry(
1348 lambda: conn.read_response(), lambda error: self._mock(error)
1349 )
1350 tmp_queue.append(conn)
1351
1352 while tmp_queue:
1353 conn = tmp_queue.popleft()
1354 self._free.append(conn)
1355
1356 async def _mock(self, error: RedisError):
1357 """
1358 Dummy functions, needs to be passed as error callback to retry object.
1359 :param error:
1360 :return:
1361 """
1362 pass
1363
1364
1365class NodesManager:
1366 __slots__ = (
1367 "_dynamic_startup_nodes",
1368 "_moved_exception",
1369 "_event_dispatcher",
1370 "connection_kwargs",
1371 "default_node",
1372 "nodes_cache",
1373 "read_load_balancer",
1374 "require_full_coverage",
1375 "slots_cache",
1376 "startup_nodes",
1377 "address_remap",
1378 )
1379
1380 def __init__(
1381 self,
1382 startup_nodes: List["ClusterNode"],
1383 require_full_coverage: bool,
1384 connection_kwargs: Dict[str, Any],
1385 dynamic_startup_nodes: bool = True,
1386 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1387 event_dispatcher: Optional[EventDispatcher] = None,
1388 ) -> None:
1389 self.startup_nodes = {node.name: node for node in startup_nodes}
1390 self.require_full_coverage = require_full_coverage
1391 self.connection_kwargs = connection_kwargs
1392 self.address_remap = address_remap
1393
1394 self.default_node: "ClusterNode" = None
1395 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1396 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1397 self.read_load_balancer = LoadBalancer()
1398
1399 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1400 self._moved_exception: MovedError = None
1401 if event_dispatcher is None:
1402 self._event_dispatcher = EventDispatcher()
1403 else:
1404 self._event_dispatcher = event_dispatcher
1405
1406 def get_node(
1407 self,
1408 host: Optional[str] = None,
1409 port: Optional[int] = None,
1410 node_name: Optional[str] = None,
1411 ) -> Optional["ClusterNode"]:
1412 if host and port:
1413 # the user passed host and port
1414 if host == "localhost":
1415 host = socket.gethostbyname(host)
1416 return self.nodes_cache.get(get_node_name(host=host, port=port))
1417 elif node_name:
1418 return self.nodes_cache.get(node_name)
1419 else:
1420 raise DataError(
1421 "get_node requires one of the following: 1. node name 2. host and port"
1422 )
1423
1424 def set_nodes(
1425 self,
1426 old: Dict[str, "ClusterNode"],
1427 new: Dict[str, "ClusterNode"],
1428 remove_old: bool = False,
1429 ) -> None:
1430 if remove_old:
1431 for name in list(old.keys()):
1432 if name not in new:
1433 task = asyncio.create_task(old.pop(name).disconnect()) # noqa
1434
1435 for name, node in new.items():
1436 if name in old:
1437 if old[name] is node:
1438 continue
1439 task = asyncio.create_task(old[name].disconnect()) # noqa
1440 old[name] = node
1441
1442 def update_moved_exception(self, exception):
1443 self._moved_exception = exception
1444
1445 def _update_moved_slots(self) -> None:
1446 e = self._moved_exception
1447 redirected_node = self.get_node(host=e.host, port=e.port)
1448 if redirected_node:
1449 # The node already exists
1450 if redirected_node.server_type != PRIMARY:
1451 # Update the node's server type
1452 redirected_node.server_type = PRIMARY
1453 else:
1454 # This is a new node, we will add it to the nodes cache
1455 redirected_node = ClusterNode(
1456 e.host, e.port, PRIMARY, **self.connection_kwargs
1457 )
1458 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
1459 if redirected_node in self.slots_cache[e.slot_id]:
1460 # The MOVED error resulted from a failover, and the new slot owner
1461 # had previously been a replica.
1462 old_primary = self.slots_cache[e.slot_id][0]
1463 # Update the old primary to be a replica and add it to the end of
1464 # the slot's node list
1465 old_primary.server_type = REPLICA
1466 self.slots_cache[e.slot_id].append(old_primary)
1467 # Remove the old replica, which is now a primary, from the slot's
1468 # node list
1469 self.slots_cache[e.slot_id].remove(redirected_node)
1470 # Override the old primary with the new one
1471 self.slots_cache[e.slot_id][0] = redirected_node
1472 if self.default_node == old_primary:
1473 # Update the default node with the new primary
1474 self.default_node = redirected_node
1475 else:
1476 # The new slot owner is a new server, or a server from a different
1477 # shard. We need to remove all current nodes from the slot's list
1478 # (including replications) and add just the new node.
1479 self.slots_cache[e.slot_id] = [redirected_node]
1480 # Reset moved_exception
1481 self._moved_exception = None
1482
1483 def get_node_from_slot(
1484 self,
1485 slot: int,
1486 read_from_replicas: bool = False,
1487 load_balancing_strategy=None,
1488 ) -> "ClusterNode":
1489 if self._moved_exception:
1490 self._update_moved_slots()
1491
1492 if read_from_replicas is True and load_balancing_strategy is None:
1493 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1494
1495 try:
1496 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1497 # get the server index using the strategy defined in load_balancing_strategy
1498 primary_name = self.slots_cache[slot][0].name
1499 node_idx = self.read_load_balancer.get_server_index(
1500 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1501 )
1502 return self.slots_cache[slot][node_idx]
1503 return self.slots_cache[slot][0]
1504 except (IndexError, TypeError):
1505 raise SlotNotCoveredError(
1506 f'Slot "{slot}" not covered by the cluster. '
1507 f'"require_full_coverage={self.require_full_coverage}"'
1508 )
1509
1510 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1511 return [
1512 node
1513 for node in self.nodes_cache.values()
1514 if node.server_type == server_type
1515 ]
1516
1517 async def initialize(self) -> None:
1518 self.read_load_balancer.reset()
1519 tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
1520 tmp_slots: Dict[int, List["ClusterNode"]] = {}
1521 disagreements = []
1522 startup_nodes_reachable = False
1523 fully_covered = False
1524 exception = None
1525 # Convert to tuple to prevent RuntimeError if self.startup_nodes
1526 # is modified during iteration
1527 for startup_node in tuple(self.startup_nodes.values()):
1528 try:
1529 # Make sure cluster mode is enabled on this node
1530 try:
1531 self._event_dispatcher.dispatch(
1532 AfterAsyncClusterInstantiationEvent(
1533 self.nodes_cache,
1534 self.connection_kwargs.get("credential_provider", None),
1535 )
1536 )
1537 cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
1538 except ResponseError:
1539 raise RedisClusterException(
1540 "Cluster mode is not enabled on this node"
1541 )
1542 startup_nodes_reachable = True
1543 except Exception as e:
1544 # Try the next startup node.
1545 # The exception is saved and raised only if we have no more nodes.
1546 exception = e
1547 continue
1548
1549 # CLUSTER SLOTS command results in the following output:
1550 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
1551 # where each node contains the following list: [IP, port, node_id]
1552 # Therefore, cluster_slots[0][2][0] will be the IP address of the
1553 # primary node of the first slot section.
1554 # If there's only one server in the cluster, its ``host`` is ''
1555 # Fix it to the host in startup_nodes
1556 if (
1557 len(cluster_slots) == 1
1558 and not cluster_slots[0][2][0]
1559 and len(self.startup_nodes) == 1
1560 ):
1561 cluster_slots[0][2][0] = startup_node.host
1562
1563 for slot in cluster_slots:
1564 for i in range(2, len(slot)):
1565 slot[i] = [str_if_bytes(val) for val in slot[i]]
1566 primary_node = slot[2]
1567 host = primary_node[0]
1568 if host == "":
1569 host = startup_node.host
1570 port = int(primary_node[1])
1571 host, port = self.remap_host_port(host, port)
1572
1573 nodes_for_slot = []
1574
1575 target_node = tmp_nodes_cache.get(get_node_name(host, port))
1576 if not target_node:
1577 target_node = ClusterNode(
1578 host, port, PRIMARY, **self.connection_kwargs
1579 )
1580 # add this node to the nodes cache
1581 tmp_nodes_cache[target_node.name] = target_node
1582 nodes_for_slot.append(target_node)
1583
1584 replica_nodes = slot[3:]
1585 for replica_node in replica_nodes:
1586 host = replica_node[0]
1587 port = replica_node[1]
1588 host, port = self.remap_host_port(host, port)
1589
1590 target_replica_node = tmp_nodes_cache.get(get_node_name(host, port))
1591 if not target_replica_node:
1592 target_replica_node = ClusterNode(
1593 host, port, REPLICA, **self.connection_kwargs
1594 )
1595 # add this node to the nodes cache
1596 tmp_nodes_cache[target_replica_node.name] = target_replica_node
1597 nodes_for_slot.append(target_replica_node)
1598
1599 for i in range(int(slot[0]), int(slot[1]) + 1):
1600 if i not in tmp_slots:
1601 tmp_slots[i] = nodes_for_slot
1602 else:
1603 # Validate that 2 nodes want to use the same slot cache
1604 # setup
1605 tmp_slot = tmp_slots[i][0]
1606 if tmp_slot.name != target_node.name:
1607 disagreements.append(
1608 f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
1609 )
1610
1611 if len(disagreements) > 5:
1612 raise RedisClusterException(
1613 f"startup_nodes could not agree on a valid "
1614 f"slots cache: {', '.join(disagreements)}"
1615 )
1616
1617 # Validate if all slots are covered or if we should try next startup node
1618 fully_covered = True
1619 for i in range(REDIS_CLUSTER_HASH_SLOTS):
1620 if i not in tmp_slots:
1621 fully_covered = False
1622 break
1623 if fully_covered:
1624 break
1625
1626 if not startup_nodes_reachable:
1627 raise RedisClusterException(
1628 f"Redis Cluster cannot be connected. Please provide at least "
1629 f"one reachable node: {str(exception)}"
1630 ) from exception
1631
1632 # Check if the slots are not fully covered
1633 if not fully_covered and self.require_full_coverage:
1634 # Despite the requirement that the slots be covered, there
1635 # isn't a full coverage
1636 raise RedisClusterException(
1637 f"All slots are not covered after query all startup_nodes. "
1638 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
1639 f"covered..."
1640 )
1641
1642 # Set the tmp variables to the real variables
1643 self.slots_cache = tmp_slots
1644 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1645
1646 if self._dynamic_startup_nodes:
1647 # Populate the startup nodes with all discovered nodes
1648 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1649
1650 # Set the default node
1651 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
1652 # If initialize was called after a MovedError, clear it
1653 self._moved_exception = None
1654
1655 async def aclose(self, attr: str = "nodes_cache") -> None:
1656 self.default_node = None
1657 await asyncio.gather(
1658 *(
1659 asyncio.create_task(node.disconnect())
1660 for node in getattr(self, attr).values()
1661 )
1662 )
1663
1664 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1665 """
1666 Remap the host and port returned from the cluster to a different
1667 internal value. Useful if the client is not connecting directly
1668 to the cluster.
1669 """
1670 if self.address_remap:
1671 return self.address_remap((host, port))
1672 return host, port
1673
1674
1675class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
1676 """
1677 Create a new ClusterPipeline object.
1678
1679 Usage::
1680
1681 result = await (
1682 rc.pipeline()
1683 .set("A", 1)
1684 .get("A")
1685 .hset("K", "F", "V")
1686 .hgetall("K")
1687 .mset_nonatomic({"A": 2, "B": 3})
1688 .get("A")
1689 .get("B")
1690 .delete("A", "B", "K")
1691 .execute()
1692 )
1693 # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
1694
1695 Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
1696 are split across multiple nodes, you'll get multiple results for them in the array.
1697
1698 Retryable errors:
1699 - :class:`~.ClusterDownError`
1700 - :class:`~.ConnectionError`
1701 - :class:`~.TimeoutError`
1702
1703 Redirection errors:
1704 - :class:`~.TryAgainError`
1705 - :class:`~.MovedError`
1706 - :class:`~.AskError`
1707
1708 :param client:
1709 | Existing :class:`~.RedisCluster` client
1710 """
1711
1712 __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1713
1714 def __init__(
1715 self, client: RedisCluster, transaction: Optional[bool] = None
1716 ) -> None:
1717 self.cluster_client = client
1718 self._transaction = transaction
1719 self._execution_strategy: ExecutionStrategy = (
1720 PipelineStrategy(self)
1721 if not self._transaction
1722 else TransactionStrategy(self)
1723 )
1724
1725 async def initialize(self) -> "ClusterPipeline":
1726 await self._execution_strategy.initialize()
1727 return self
1728
1729 async def __aenter__(self) -> "ClusterPipeline":
1730 return await self.initialize()
1731
1732 async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1733 await self.reset()
1734
1735 def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1736 return self.initialize().__await__()
1737
1738 def __bool__(self) -> bool:
1739 "Pipeline instances should always evaluate to True on Python 3+"
1740 return True
1741
1742 def __len__(self) -> int:
1743 return len(self._execution_strategy)
1744
1745 def execute_command(
1746 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1747 ) -> "ClusterPipeline":
1748 """
1749 Append a raw command to the pipeline.
1750
1751 :param args:
1752 | Raw command args
1753 :param kwargs:
1754
1755 - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
1756 or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1757 - Rest of the kwargs are passed to the Redis connection
1758 """
1759 return self._execution_strategy.execute_command(*args, **kwargs)
1760
1761 async def execute(
1762 self, raise_on_error: bool = True, allow_redirections: bool = True
1763 ) -> List[Any]:
1764 """
1765 Execute the pipeline.
1766
1767 It will retry the commands as specified by retries specified in :attr:`retry`
1768 & then raise an exception.
1769
1770 :param raise_on_error:
1771 | Raise the first error if there are any errors
1772 :param allow_redirections:
1773 | Whether to retry each failed command individually in case of redirection
1774 errors
1775
1776 :raises RedisClusterException: if target_nodes is not provided & the command
1777 can't be mapped to a slot
1778 """
1779 try:
1780 return await self._execution_strategy.execute(
1781 raise_on_error, allow_redirections
1782 )
1783 finally:
1784 await self.reset()
1785
1786 def _split_command_across_slots(
1787 self, command: str, *keys: KeyT
1788 ) -> "ClusterPipeline":
1789 for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1790 self.execute_command(command, *slot_keys)
1791
1792 return self
1793
1794 async def reset(self):
1795 """
1796 Reset back to empty pipeline.
1797 """
1798 await self._execution_strategy.reset()
1799
1800 def multi(self):
1801 """
1802 Start a transactional block of the pipeline after WATCH commands
1803 are issued. End the transactional block with `execute`.
1804 """
1805 self._execution_strategy.multi()
1806
1807 async def discard(self):
1808 """ """
1809 await self._execution_strategy.discard()
1810
1811 async def watch(self, *names):
1812 """Watches the values at keys ``names``"""
1813 await self._execution_strategy.watch(*names)
1814
1815 async def unwatch(self):
1816 """Unwatches all previously specified keys"""
1817 await self._execution_strategy.unwatch()
1818
1819 async def unlink(self, *names):
1820 await self._execution_strategy.unlink(*names)
1821
1822 def mset_nonatomic(
1823 self, mapping: Mapping[AnyKeyT, EncodableT]
1824 ) -> "ClusterPipeline":
1825 return self._execution_strategy.mset_nonatomic(mapping)
1826
1827
1828for command in PIPELINE_BLOCKED_COMMANDS:
1829 command = command.replace(" ", "_").lower()
1830 if command == "mset_nonatomic":
1831 continue
1832
1833 setattr(ClusterPipeline, command, block_pipeline_command(command))
1834
1835
1836class PipelineCommand:
1837 def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
1838 self.args = args
1839 self.kwargs = kwargs
1840 self.position = position
1841 self.result: Union[Any, Exception] = None
1842 self.command_policies: Optional[CommandPolicies] = None
1843
1844 def __repr__(self) -> str:
1845 return f"[{self.position}] {self.args} ({self.kwargs})"
1846
1847
1848class ExecutionStrategy(ABC):
1849 @abstractmethod
1850 async def initialize(self) -> "ClusterPipeline":
1851 """
1852 Initialize the execution strategy.
1853
1854 See ClusterPipeline.initialize()
1855 """
1856 pass
1857
1858 @abstractmethod
1859 def execute_command(
1860 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1861 ) -> "ClusterPipeline":
1862 """
1863 Append a raw command to the pipeline.
1864
1865 See ClusterPipeline.execute_command()
1866 """
1867 pass
1868
1869 @abstractmethod
1870 async def execute(
1871 self, raise_on_error: bool = True, allow_redirections: bool = True
1872 ) -> List[Any]:
1873 """
1874 Execute the pipeline.
1875
1876 It will retry the commands as specified by retries specified in :attr:`retry`
1877 & then raise an exception.
1878
1879 See ClusterPipeline.execute()
1880 """
1881 pass
1882
1883 @abstractmethod
1884 def mset_nonatomic(
1885 self, mapping: Mapping[AnyKeyT, EncodableT]
1886 ) -> "ClusterPipeline":
1887 """
1888 Executes multiple MSET commands according to the provided slot/pairs mapping.
1889
1890 See ClusterPipeline.mset_nonatomic()
1891 """
1892 pass
1893
1894 @abstractmethod
1895 async def reset(self):
1896 """
1897 Resets current execution strategy.
1898
1899 See: ClusterPipeline.reset()
1900 """
1901 pass
1902
1903 @abstractmethod
1904 def multi(self):
1905 """
1906 Starts transactional context.
1907
1908 See: ClusterPipeline.multi()
1909 """
1910 pass
1911
1912 @abstractmethod
1913 async def watch(self, *names):
1914 """
1915 Watch given keys.
1916
1917 See: ClusterPipeline.watch()
1918 """
1919 pass
1920
1921 @abstractmethod
1922 async def unwatch(self):
1923 """
1924 Unwatches all previously specified keys
1925
1926 See: ClusterPipeline.unwatch()
1927 """
1928 pass
1929
1930 @abstractmethod
1931 async def discard(self):
1932 pass
1933
1934 @abstractmethod
1935 async def unlink(self, *names):
1936 """
1937 "Unlink a key specified by ``names``"
1938
1939 See: ClusterPipeline.unlink()
1940 """
1941 pass
1942
1943 @abstractmethod
1944 def __len__(self) -> int:
1945 pass
1946
1947
1948class AbstractStrategy(ExecutionStrategy):
1949 def __init__(self, pipe: ClusterPipeline) -> None:
1950 self._pipe: ClusterPipeline = pipe
1951 self._command_queue: List["PipelineCommand"] = []
1952
1953 async def initialize(self) -> "ClusterPipeline":
1954 if self._pipe.cluster_client._initialize:
1955 await self._pipe.cluster_client.initialize()
1956 self._command_queue = []
1957 return self._pipe
1958
1959 def execute_command(
1960 self, *args: Union[KeyT, EncodableT], **kwargs: Any
1961 ) -> "ClusterPipeline":
1962 self._command_queue.append(
1963 PipelineCommand(len(self._command_queue), *args, **kwargs)
1964 )
1965 return self._pipe
1966
1967 def _annotate_exception(self, exception, number, command):
1968 """
1969 Provides extra context to the exception prior to it being handled
1970 """
1971 cmd = " ".join(map(safe_str, command))
1972 msg = (
1973 f"Command # {number} ({truncate_text(cmd)}) of pipeline "
1974 f"caused error: {exception.args[0]}"
1975 )
1976 exception.args = (msg,) + exception.args[1:]
1977
1978 @abstractmethod
1979 def mset_nonatomic(
1980 self, mapping: Mapping[AnyKeyT, EncodableT]
1981 ) -> "ClusterPipeline":
1982 pass
1983
1984 @abstractmethod
1985 async def execute(
1986 self, raise_on_error: bool = True, allow_redirections: bool = True
1987 ) -> List[Any]:
1988 pass
1989
1990 @abstractmethod
1991 async def reset(self):
1992 pass
1993
1994 @abstractmethod
1995 def multi(self):
1996 pass
1997
1998 @abstractmethod
1999 async def watch(self, *names):
2000 pass
2001
2002 @abstractmethod
2003 async def unwatch(self):
2004 pass
2005
2006 @abstractmethod
2007 async def discard(self):
2008 pass
2009
2010 @abstractmethod
2011 async def unlink(self, *names):
2012 pass
2013
2014 def __len__(self) -> int:
2015 return len(self._command_queue)
2016
2017
2018class PipelineStrategy(AbstractStrategy):
2019 def __init__(self, pipe: ClusterPipeline) -> None:
2020 super().__init__(pipe)
2021
2022 def mset_nonatomic(
2023 self, mapping: Mapping[AnyKeyT, EncodableT]
2024 ) -> "ClusterPipeline":
2025 encoder = self._pipe.cluster_client.encoder
2026
2027 slots_pairs = {}
2028 for pair in mapping.items():
2029 slot = key_slot(encoder.encode(pair[0]))
2030 slots_pairs.setdefault(slot, []).extend(pair)
2031
2032 for pairs in slots_pairs.values():
2033 self.execute_command("MSET", *pairs)
2034
2035 return self._pipe
2036
2037 async def execute(
2038 self, raise_on_error: bool = True, allow_redirections: bool = True
2039 ) -> List[Any]:
2040 if not self._command_queue:
2041 return []
2042
2043 try:
2044 retry_attempts = self._pipe.cluster_client.retry.get_retries()
2045 while True:
2046 try:
2047 if self._pipe.cluster_client._initialize:
2048 await self._pipe.cluster_client.initialize()
2049 return await self._execute(
2050 self._pipe.cluster_client,
2051 self._command_queue,
2052 raise_on_error=raise_on_error,
2053 allow_redirections=allow_redirections,
2054 )
2055
2056 except RedisCluster.ERRORS_ALLOW_RETRY as e:
2057 if retry_attempts > 0:
2058 # Try again with the new cluster setup. All other errors
2059 # should be raised.
2060 retry_attempts -= 1
2061 await self._pipe.cluster_client.aclose()
2062 await asyncio.sleep(0.25)
2063 else:
2064 # All other errors should be raised.
2065 raise e
2066 finally:
2067 await self.reset()
2068
2069 async def _execute(
2070 self,
2071 client: "RedisCluster",
2072 stack: List["PipelineCommand"],
2073 raise_on_error: bool = True,
2074 allow_redirections: bool = True,
2075 ) -> List[Any]:
2076 todo = [
2077 cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
2078 ]
2079
2080 nodes = {}
2081 for cmd in todo:
2082 passed_targets = cmd.kwargs.pop("target_nodes", None)
2083 command_policies = await client._policy_resolver.resolve(
2084 cmd.args[0].lower()
2085 )
2086
2087 if passed_targets and not client._is_node_flag(passed_targets):
2088 target_nodes = client._parse_target_nodes(passed_targets)
2089
2090 if not command_policies:
2091 command_policies = CommandPolicies()
2092 else:
2093 if not command_policies:
2094 command_flag = client.command_flags.get(cmd.args[0])
2095 if not command_flag:
2096 # Fallback to default policy
2097 if not client.get_default_node():
2098 slot = None
2099 else:
2100 slot = await client._determine_slot(*cmd.args)
2101 if not slot:
2102 command_policies = CommandPolicies()
2103 else:
2104 command_policies = CommandPolicies(
2105 request_policy=RequestPolicy.DEFAULT_KEYED,
2106 response_policy=ResponsePolicy.DEFAULT_KEYED,
2107 )
2108 else:
2109 if command_flag in client._command_flags_mapping:
2110 command_policies = CommandPolicies(
2111 request_policy=client._command_flags_mapping[
2112 command_flag
2113 ]
2114 )
2115 else:
2116 command_policies = CommandPolicies()
2117
2118 target_nodes = await client._determine_nodes(
2119 *cmd.args,
2120 request_policy=command_policies.request_policy,
2121 node_flag=passed_targets,
2122 )
2123 if not target_nodes:
2124 raise RedisClusterException(
2125 f"No targets were found to execute {cmd.args} command on"
2126 )
2127 cmd.command_policies = command_policies
2128 if len(target_nodes) > 1:
2129 raise RedisClusterException(f"Too many targets for command {cmd.args}")
2130 node = target_nodes[0]
2131 if node.name not in nodes:
2132 nodes[node.name] = (node, [])
2133 nodes[node.name][1].append(cmd)
2134
2135 errors = await asyncio.gather(
2136 *(
2137 asyncio.create_task(node[0].execute_pipeline(node[1]))
2138 for node in nodes.values()
2139 )
2140 )
2141
2142 if any(errors):
2143 if allow_redirections:
2144 # send each errored command individually
2145 for cmd in todo:
2146 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
2147 try:
2148 cmd.result = client._policies_callback_mapping[
2149 cmd.command_policies.response_policy
2150 ](await client.execute_command(*cmd.args, **cmd.kwargs))
2151 except Exception as e:
2152 cmd.result = e
2153
2154 if raise_on_error:
2155 for cmd in todo:
2156 result = cmd.result
2157 if isinstance(result, Exception):
2158 command = " ".join(map(safe_str, cmd.args))
2159 msg = (
2160 f"Command # {cmd.position + 1} "
2161 f"({truncate_text(command)}) "
2162 f"of pipeline caused error: {result.args}"
2163 )
2164 result.args = (msg,) + result.args[1:]
2165 raise result
2166
2167 default_cluster_node = client.get_default_node()
2168
2169 # Check whether the default node was used. In some cases,
2170 # 'client.get_default_node()' may return None. The check below
2171 # prevents a potential AttributeError.
2172 if default_cluster_node is not None:
2173 default_node = nodes.get(default_cluster_node.name)
2174 if default_node is not None:
2175 # This pipeline execution used the default node, check if we need
2176 # to replace it.
2177 # Note: when the error is raised we'll reset the default node in the
2178 # caller function.
2179 for cmd in default_node[1]:
2180 # Check if it has a command that failed with a relevant
2181 # exception
2182 if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
2183 client.replace_default_node()
2184 break
2185
2186 return [cmd.result for cmd in stack]
2187
2188 async def reset(self):
2189 """
2190 Reset back to empty pipeline.
2191 """
2192 self._command_queue = []
2193
2194 def multi(self):
2195 raise RedisClusterException(
2196 "method multi() is not supported outside of transactional context"
2197 )
2198
2199 async def watch(self, *names):
2200 raise RedisClusterException(
2201 "method watch() is not supported outside of transactional context"
2202 )
2203
2204 async def unwatch(self):
2205 raise RedisClusterException(
2206 "method unwatch() is not supported outside of transactional context"
2207 )
2208
2209 async def discard(self):
2210 raise RedisClusterException(
2211 "method discard() is not supported outside of transactional context"
2212 )
2213
2214 async def unlink(self, *names):
2215 if len(names) != 1:
2216 raise RedisClusterException(
2217 "unlinking multiple keys is not implemented in pipeline command"
2218 )
2219
2220 return self.execute_command("UNLINK", names[0])
2221
2222
2223class TransactionStrategy(AbstractStrategy):
2224 NO_SLOTS_COMMANDS = {"UNWATCH"}
2225 IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
2226 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
2227 SLOT_REDIRECT_ERRORS = (AskError, MovedError)
2228 CONNECTION_ERRORS = (
2229 ConnectionError,
2230 OSError,
2231 ClusterDownError,
2232 SlotNotCoveredError,
2233 )
2234
2235 def __init__(self, pipe: ClusterPipeline) -> None:
2236 super().__init__(pipe)
2237 self._explicit_transaction = False
2238 self._watching = False
2239 self._pipeline_slots: Set[int] = set()
2240 self._transaction_node: Optional[ClusterNode] = None
2241 self._transaction_connection: Optional[Connection] = None
2242 self._executing = False
2243 self._retry = copy(self._pipe.cluster_client.retry)
2244 self._retry.update_supported_errors(
2245 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2246 )
2247
2248 def _get_client_and_connection_for_transaction(
2249 self,
2250 ) -> Tuple[ClusterNode, Connection]:
2251 """
2252 Find a connection for a pipeline transaction.
2253
2254 For running an atomic transaction, watch keys ensure that contents have not been
2255 altered as long as the watch commands for those keys were sent over the same
2256 connection. So once we start watching a key, we fetch a connection to the
2257 node that owns that slot and reuse it.
2258 """
2259 if not self._pipeline_slots:
2260 raise RedisClusterException(
2261 "At least a command with a key is needed to identify a node"
2262 )
2263
2264 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2265 list(self._pipeline_slots)[0], False
2266 )
2267 self._transaction_node = node
2268
2269 if not self._transaction_connection:
2270 connection: Connection = self._transaction_node.acquire_connection()
2271 self._transaction_connection = connection
2272
2273 return self._transaction_node, self._transaction_connection
2274
2275 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2276 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2277 response = None
2278 error = None
2279
2280 def runner():
2281 nonlocal response
2282 nonlocal error
2283 try:
2284 response = asyncio.run(self._execute_command(*args, **kwargs))
2285 except Exception as e:
2286 error = e
2287
2288 thread = threading.Thread(target=runner)
2289 thread.start()
2290 thread.join()
2291
2292 if error:
2293 raise error
2294
2295 return response
2296
2297 async def _execute_command(
2298 self, *args: Union[KeyT, EncodableT], **kwargs: Any
2299 ) -> Any:
2300 if self._pipe.cluster_client._initialize:
2301 await self._pipe.cluster_client.initialize()
2302
2303 slot_number: Optional[int] = None
2304 if args[0] not in self.NO_SLOTS_COMMANDS:
2305 slot_number = await self._pipe.cluster_client._determine_slot(*args)
2306
2307 if (
2308 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2309 ) and not self._explicit_transaction:
2310 if args[0] == "WATCH":
2311 self._validate_watch()
2312
2313 if slot_number is not None:
2314 if self._pipeline_slots and slot_number not in self._pipeline_slots:
2315 raise CrossSlotTransactionError(
2316 "Cannot watch or send commands on different slots"
2317 )
2318
2319 self._pipeline_slots.add(slot_number)
2320 elif args[0] not in self.NO_SLOTS_COMMANDS:
2321 raise RedisClusterException(
2322 f"Cannot identify slot number for command: {args[0]},"
2323 "it cannot be triggered in a transaction"
2324 )
2325
2326 return self._immediate_execute_command(*args, **kwargs)
2327 else:
2328 if slot_number is not None:
2329 self._pipeline_slots.add(slot_number)
2330
2331 return super().execute_command(*args, **kwargs)
2332
2333 def _validate_watch(self):
2334 if self._explicit_transaction:
2335 raise RedisError("Cannot issue a WATCH after a MULTI")
2336
2337 self._watching = True
2338
2339 async def _immediate_execute_command(self, *args, **options):
2340 return await self._retry.call_with_retry(
2341 lambda: self._get_connection_and_send_command(*args, **options),
2342 self._reinitialize_on_error,
2343 )
2344
2345 async def _get_connection_and_send_command(self, *args, **options):
2346 redis_node, connection = self._get_client_and_connection_for_transaction()
2347 return await self._send_command_parse_response(
2348 connection, redis_node, args[0], *args, **options
2349 )
2350
2351 async def _send_command_parse_response(
2352 self,
2353 connection: Connection,
2354 redis_node: ClusterNode,
2355 command_name,
2356 *args,
2357 **options,
2358 ):
2359 """
2360 Send a command and parse the response
2361 """
2362
2363 await connection.send_command(*args)
2364 output = await redis_node.parse_response(connection, command_name, **options)
2365
2366 if command_name in self.UNWATCH_COMMANDS:
2367 self._watching = False
2368 return output
2369
2370 async def _reinitialize_on_error(self, error):
2371 if self._watching:
2372 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2373 raise WatchError("Slot rebalancing occurred while watching keys")
2374
2375 if (
2376 type(error) in self.SLOT_REDIRECT_ERRORS
2377 or type(error) in self.CONNECTION_ERRORS
2378 ):
2379 if self._transaction_connection:
2380 self._transaction_connection = None
2381
2382 self._pipe.cluster_client.reinitialize_counter += 1
2383 if (
2384 self._pipe.cluster_client.reinitialize_steps
2385 and self._pipe.cluster_client.reinitialize_counter
2386 % self._pipe.cluster_client.reinitialize_steps
2387 == 0
2388 ):
2389 await self._pipe.cluster_client.nodes_manager.initialize()
2390 self.reinitialize_counter = 0
2391 else:
2392 if isinstance(error, AskError):
2393 self._pipe.cluster_client.nodes_manager.update_moved_exception(
2394 error
2395 )
2396
2397 self._executing = False
2398
2399 def _raise_first_error(self, responses, stack):
2400 """
2401 Raise the first exception on the stack
2402 """
2403 for r, cmd in zip(responses, stack):
2404 if isinstance(r, Exception):
2405 self._annotate_exception(r, cmd.position + 1, cmd.args)
2406 raise r
2407
2408 def mset_nonatomic(
2409 self, mapping: Mapping[AnyKeyT, EncodableT]
2410 ) -> "ClusterPipeline":
2411 raise NotImplementedError("Method is not supported in transactional context.")
2412
2413 async def execute(
2414 self, raise_on_error: bool = True, allow_redirections: bool = True
2415 ) -> List[Any]:
2416 stack = self._command_queue
2417 if not stack and (not self._watching or not self._pipeline_slots):
2418 return []
2419
2420 return await self._execute_transaction_with_retries(stack, raise_on_error)
2421
2422 async def _execute_transaction_with_retries(
2423 self, stack: List["PipelineCommand"], raise_on_error: bool
2424 ):
2425 return await self._retry.call_with_retry(
2426 lambda: self._execute_transaction(stack, raise_on_error),
2427 self._reinitialize_on_error,
2428 )
2429
2430 async def _execute_transaction(
2431 self, stack: List["PipelineCommand"], raise_on_error: bool
2432 ):
2433 if len(self._pipeline_slots) > 1:
2434 raise CrossSlotTransactionError(
2435 "All keys involved in a cluster transaction must map to the same slot"
2436 )
2437
2438 self._executing = True
2439
2440 redis_node, connection = self._get_client_and_connection_for_transaction()
2441
2442 stack = chain(
2443 [PipelineCommand(0, "MULTI")],
2444 stack,
2445 [PipelineCommand(0, "EXEC")],
2446 )
2447 commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
2448 packed_commands = connection.pack_commands(commands)
2449 await connection.send_packed_command(packed_commands)
2450 errors = []
2451
2452 # parse off the response for MULTI
2453 # NOTE: we need to handle ResponseErrors here and continue
2454 # so that we read all the additional command messages from
2455 # the socket
2456 try:
2457 await redis_node.parse_response(connection, "MULTI")
2458 except ResponseError as e:
2459 self._annotate_exception(e, 0, "MULTI")
2460 errors.append(e)
2461 except self.CONNECTION_ERRORS as cluster_error:
2462 self._annotate_exception(cluster_error, 0, "MULTI")
2463 raise
2464
2465 # and all the other commands
2466 for i, command in enumerate(self._command_queue):
2467 if EMPTY_RESPONSE in command.kwargs:
2468 errors.append((i, command.kwargs[EMPTY_RESPONSE]))
2469 else:
2470 try:
2471 _ = await redis_node.parse_response(connection, "_")
2472 except self.SLOT_REDIRECT_ERRORS as slot_error:
2473 self._annotate_exception(slot_error, i + 1, command.args)
2474 errors.append(slot_error)
2475 except self.CONNECTION_ERRORS as cluster_error:
2476 self._annotate_exception(cluster_error, i + 1, command.args)
2477 raise
2478 except ResponseError as e:
2479 self._annotate_exception(e, i + 1, command.args)
2480 errors.append(e)
2481
2482 response = None
2483 # parse the EXEC.
2484 try:
2485 response = await redis_node.parse_response(connection, "EXEC")
2486 except ExecAbortError:
2487 if errors:
2488 raise errors[0]
2489 raise
2490
2491 self._executing = False
2492
2493 # EXEC clears any watched keys
2494 self._watching = False
2495
2496 if response is None:
2497 raise WatchError("Watched variable changed.")
2498
2499 # put any parse errors into the response
2500 for i, e in errors:
2501 response.insert(i, e)
2502
2503 if len(response) != len(self._command_queue):
2504 raise InvalidPipelineStack(
2505 "Unexpected response length for cluster pipeline EXEC."
2506 " Command stack was {} but response had length {}".format(
2507 [c.args[0] for c in self._command_queue], len(response)
2508 )
2509 )
2510
2511 # find any errors in the response and raise if necessary
2512 if raise_on_error or len(errors) > 0:
2513 self._raise_first_error(
2514 response,
2515 self._command_queue,
2516 )
2517
2518 # We have to run response callbacks manually
2519 data = []
2520 for r, cmd in zip(response, self._command_queue):
2521 if not isinstance(r, Exception):
2522 command_name = cmd.args[0]
2523 if command_name in self._pipe.cluster_client.response_callbacks:
2524 r = self._pipe.cluster_client.response_callbacks[command_name](
2525 r, **cmd.kwargs
2526 )
2527 data.append(r)
2528 return data
2529
2530 async def reset(self):
2531 self._command_queue = []
2532
2533 # make sure to reset the connection state in the event that we were
2534 # watching something
2535 if self._transaction_connection:
2536 try:
2537 if self._watching:
2538 # call this manually since our unwatch or
2539 # immediate_execute_command methods can call reset()
2540 await self._transaction_connection.send_command("UNWATCH")
2541 await self._transaction_connection.read_response()
2542 # we can safely return the connection to the pool here since we're
2543 # sure we're no longer WATCHing anything
2544 self._transaction_node.release(self._transaction_connection)
2545 self._transaction_connection = None
2546 except self.CONNECTION_ERRORS:
2547 # disconnect will also remove any previous WATCHes
2548 if self._transaction_connection:
2549 await self._transaction_connection.disconnect()
2550
2551 # clean up the other instance attributes
2552 self._transaction_node = None
2553 self._watching = False
2554 self._explicit_transaction = False
2555 self._pipeline_slots = set()
2556 self._executing = False
2557
2558 def multi(self):
2559 if self._explicit_transaction:
2560 raise RedisError("Cannot issue nested calls to MULTI")
2561 if self._command_queue:
2562 raise RedisError(
2563 "Commands without an initial WATCH have already been issued"
2564 )
2565 self._explicit_transaction = True
2566
2567 async def watch(self, *names):
2568 if self._explicit_transaction:
2569 raise RedisError("Cannot issue a WATCH after a MULTI")
2570
2571 return await self.execute_command("WATCH", *names)
2572
2573 async def unwatch(self):
2574 if self._watching:
2575 return await self.execute_command("UNWATCH")
2576
2577 return True
2578
2579 async def discard(self):
2580 await self.reset()
2581
2582 async def unlink(self, *names):
2583 return self.execute_command("UNLINK", *names)