1import asyncio
2import collections
3import random
4import socket
5import threading
6import time
7import warnings
8from abc import ABC, abstractmethod
9from copy import copy
10from itertools import chain
11from typing import (
12 TYPE_CHECKING,
13 Any,
14 Callable,
15 Coroutine,
16 Deque,
17 Dict,
18 Generator,
19 List,
20 Literal,
21 Mapping,
22 Optional,
23 Set,
24 Tuple,
25 Type,
26 TypeVar,
27 Union,
28)
29
30if TYPE_CHECKING:
31 from redis.asyncio.keyspace_notifications import (
32 AsyncClusterKeyspaceNotifications,
33 )
34
35from redis._parsers import AsyncCommandsParser, Encoder
36from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
37from redis._parsers.helpers import (
38 _RedisCallbacks,
39 _RedisCallbacksRESP2,
40 _RedisCallbacksRESP3,
41)
42from redis.asyncio.client import PubSub, ResponseCallbackT
43from redis.asyncio.connection import (
44 AbstractConnection,
45 Connection,
46 ConnectionPool,
47 SSLConnection,
48 parse_url,
49)
50from redis.asyncio.lock import Lock
51from redis.asyncio.observability.recorder import (
52 record_error_count,
53 record_operation_duration,
54)
55from redis.asyncio.retry import Retry
56from redis.auth.token import TokenInterface
57from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
58from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
59from redis.cluster import (
60 PIPELINE_BLOCKED_COMMANDS,
61 PRIMARY,
62 REPLICA,
63 SLOT_ID,
64 AbstractRedisCluster,
65 LoadBalancer,
66 LoadBalancingStrategy,
67 block_pipeline_command,
68 get_node_name,
69 parse_cluster_slots,
70)
71from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
72from redis.commands.helpers import list_or_args
73from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
74from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
75from redis.credentials import CredentialProvider
76from redis.driver_info import DriverInfo, resolve_driver_info
77from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
78from redis.exceptions import (
79 AskError,
80 BusyLoadingError,
81 ClusterDownError,
82 ClusterError,
83 ConnectionError,
84 CrossSlotTransactionError,
85 DataError,
86 ExecAbortError,
87 InvalidPipelineStack,
88 MaxConnectionsError,
89 MovedError,
90 RedisClusterException,
91 RedisError,
92 ResponseError,
93 SlotNotCoveredError,
94 TimeoutError,
95 TryAgainError,
96 WatchError,
97)
98from redis.typing import AnyKeyT, EncodableT, KeyT
99from redis.utils import (
100 SSL_AVAILABLE,
101 deprecated_args,
102 deprecated_function,
103 safe_str,
104 str_if_bytes,
105 truncate_text,
106)
107
if SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    # ssl module unavailable (e.g. Python built without OpenSSL): keep the
    # names bound so the annotations in __init__ below still evaluate.
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Accepted forms of the ``target_nodes`` argument on cluster commands:
# a node flag string, a single node, a list of nodes, or a name->node mapping.
TargetNodesT = TypeVar(
    "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
118
119
120class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
121 """
122 Create a new RedisCluster client.
123
124 Pass one of parameters:
125
126 - `host` & `port`
127 - `startup_nodes`
128
129 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
130 | Use ``await`` :meth:`close` to disconnect connections & close client.
131
132 Many commands support the target_nodes kwarg. It can be one of the
133 :attr:`NODE_FLAGS`:
134
135 - :attr:`PRIMARIES`
136 - :attr:`REPLICAS`
137 - :attr:`ALL_NODES`
138 - :attr:`RANDOM`
139 - :attr:`DEFAULT_NODE`
140
141 Note: This client is not thread/process/fork safe.
142
143 :param host:
144 | Can be used to point to a startup node
145 :param port:
146 | Port used if **host** is provided
147 :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
149 :param require_full_coverage:
150 | When set to ``False``: the client will not require a full coverage of
151 the slots. However, if not all slots are covered, and at least one node
152 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
153 a :class:`~.ClusterDownError` for some key-based commands.
154 | When set to ``True``: all slots must be covered to construct the cluster
155 client. If not all slots are covered, :class:`~.RedisClusterException` will be
156 thrown.
157 | See:
158 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
159 :param read_from_replicas:
160 | @deprecated - please use load_balancing_strategy instead
161 | Enable read from replicas in READONLY mode.
162 When set to true, read commands will be assigned between the primary and
          its replicas in a Round-Robin manner.
164 The data read from replicas is eventually consistent with the data in primary nodes.
165 :param load_balancing_strategy:
166 | Enable read from replicas in READONLY mode and defines the load balancing
167 strategy that will be used for cluster node selection.
168 The data read from replicas is eventually consistent with the data in primary nodes.
169 :param dynamic_startup_nodes:
170 | Set the RedisCluster's startup nodes to all the discovered nodes.
171 If true (default value), the cluster's discovered nodes will be used to
172 determine the cluster nodes-slots mapping in the next topology refresh.
173 It will remove the initial passed startup nodes if their endpoints aren't
174 listed in the CLUSTER SLOTS output.
175 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
176 specific IP addresses, it is best to set it to false.
177 :param reinitialize_steps:
178 | Specifies the number of MOVED errors that need to occur before reinitializing
179 the whole cluster topology. If a MOVED error occurs and the cluster does not
180 need to be reinitialized on this current error handling, only the MOVED slot
181 will be patched with the redirected node.
182 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
183 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
184 0.
185 :param cluster_error_retry_attempts:
186 | @deprecated - Please configure the 'retry' object instead
187 In case 'retry' object is set - this argument is ignored!
188
189 Number of times to retry before raising an error when :class:`~.TimeoutError`,
190 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
191 or :class:`~.ClusterDownError` are encountered
192 :param retry:
193 | A retry object that defines the retry strategy and the number of
194 retries for the cluster client.
          In current implementation for the cluster client (starting from redis-py version 6.0.0)
196 the retry object is not yet fully utilized, instead it is used just to determine
197 the number of retries for the cluster client.
198 In the future releases the retry object will be used to handle the cluster client retries!
199 :param max_connections:
200 | Maximum number of connections per node. If there are no free connections & the
201 maximum number of connections are already created, a
202 :class:`~.MaxConnectionsError` is raised.
203 :param address_remap:
204 | An optional callable which, when provided with an internal network
205 address of a node, e.g. a `(host, port)` tuple, will return the address
206 where the node is reachable. This can be used to map the addresses at
207 which the nodes _think_ they are, to addresses at which a client may
208 reach them, such as when they sit behind a proxy.
209
210 | Rest of the arguments will be passed to the
211 :class:`~redis.asyncio.connection.Connection` instances when created
212
213 :raises RedisClusterException:
214 if any arguments are invalid or unknown. Eg:
215
216 - `db` != 0 or None
217 - `path` argument for unix socket connection
218 - none of the `host`/`port` & `startup_nodes` were provided
219
220 """
221
222 @classmethod
223 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
224 """
225 Return a Redis client object configured from the given URL.
226
227 For example::
228
229 redis://[[username]:[password]]@localhost:6379/0
230 rediss://[[username]:[password]]@localhost:6379/0
231
232 Three URL schemes are supported:
233
234 - `redis://` creates a TCP socket connection. See more at:
235 <https://www.iana.org/assignments/uri-schemes/prov/redis>
236 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
237 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
238
239 The username, password, hostname, path and all querystring values are passed
240 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
241 with their corresponding characters.
242
243 All querystring options are cast to their appropriate Python types. Boolean
244 arguments can be specified with string values "True"/"False" or "Yes"/"No".
245 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
246 parsed, the querystring arguments and keyword arguments are passed to
247 :class:`~redis.asyncio.connection.Connection` when created.
248 In the case of conflicting arguments, querystring arguments are used.
249 """
250 kwargs.update(parse_url(url))
251 if kwargs.pop("connection_class", None) is SSLConnection:
252 kwargs["ssl"] = True
253 return cls(**kwargs)
254
    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    # NOTE(review): several attributes assigned in __init__ (e.g.
    # _event_dispatcher, startup_nodes, load_balancing_strategy,
    # _usage_counter, _usage_lock) are not listed here, so they must land in
    # a __dict__ contributed by a base class — confirm this is intentional.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
274
275 @deprecated_args(
276 args_to_warn=["read_from_replicas"],
277 reason="Please configure the 'load_balancing_strategy' instead",
278 version="5.3.0",
279 )
280 @deprecated_args(
281 args_to_warn=[
282 "cluster_error_retry_attempts",
283 ],
284 reason="Please configure the 'retry' object instead",
285 version="6.0.0",
286 )
287 @deprecated_args(
288 args_to_warn=["lib_name", "lib_version"],
289 reason="Use 'driver_info' parameter instead. "
290 "lib_name and lib_version will be removed in a future version.",
291 )
292 def __init__(
293 self,
294 host: Optional[str] = None,
295 port: Union[str, int] = 6379,
296 # Cluster related kwargs
297 startup_nodes: Optional[List["ClusterNode"]] = None,
298 require_full_coverage: bool = True,
299 read_from_replicas: bool = False,
300 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
301 dynamic_startup_nodes: bool = True,
302 reinitialize_steps: int = 5,
303 cluster_error_retry_attempts: int = 3,
304 max_connections: int = 2**31,
305 retry: Optional["Retry"] = None,
306 retry_on_error: Optional[List[Type[Exception]]] = None,
307 # Client related kwargs
308 db: Union[str, int] = 0,
309 path: Optional[str] = None,
310 credential_provider: Optional[CredentialProvider] = None,
311 username: Optional[str] = None,
312 password: Optional[str] = None,
313 client_name: Optional[str] = None,
314 lib_name: Optional[str] = None,
315 lib_version: Optional[str] = None,
316 driver_info: Optional["DriverInfo"] = None,
317 # Encoding related kwargs
318 encoding: str = "utf-8",
319 encoding_errors: str = "strict",
320 decode_responses: bool = False,
321 # Connection related kwargs
322 health_check_interval: float = 0,
323 socket_connect_timeout: Optional[float] = None,
324 socket_keepalive: bool = False,
325 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
326 socket_timeout: Optional[float] = None,
327 # SSL related kwargs
328 ssl: bool = False,
329 ssl_ca_certs: Optional[str] = None,
330 ssl_ca_data: Optional[str] = None,
331 ssl_cert_reqs: Union[str, VerifyMode] = "required",
332 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
333 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
334 ssl_certfile: Optional[str] = None,
335 ssl_check_hostname: bool = True,
336 ssl_keyfile: Optional[str] = None,
337 ssl_min_version: Optional[TLSVersion] = None,
338 ssl_ciphers: Optional[str] = None,
339 protocol: Optional[int] = 2,
340 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
341 event_dispatcher: Optional[EventDispatcher] = None,
342 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
343 ) -> None:
344 if db:
345 raise RedisClusterException(
346 "Argument 'db' must be 0 or None in cluster mode"
347 )
348
349 if path:
350 raise RedisClusterException(
351 "Unix domain socket is not supported in cluster mode"
352 )
353
354 if (not host or not port) and not startup_nodes:
355 raise RedisClusterException(
356 "RedisCluster requires at least one node to discover the cluster.\n"
357 "Please provide one of the following or use RedisCluster.from_url:\n"
358 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
359 " - startup_nodes: RedisCluster(startup_nodes=["
360 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
361 )
362
363 computed_driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
364
365 kwargs: Dict[str, Any] = {
366 "max_connections": max_connections,
367 "connection_class": Connection,
368 # Client related kwargs
369 "credential_provider": credential_provider,
370 "username": username,
371 "password": password,
372 "client_name": client_name,
373 "driver_info": computed_driver_info,
374 # Encoding related kwargs
375 "encoding": encoding,
376 "encoding_errors": encoding_errors,
377 "decode_responses": decode_responses,
378 # Connection related kwargs
379 "health_check_interval": health_check_interval,
380 "socket_connect_timeout": socket_connect_timeout,
381 "socket_keepalive": socket_keepalive,
382 "socket_keepalive_options": socket_keepalive_options,
383 "socket_timeout": socket_timeout,
384 "protocol": protocol,
385 }
386
387 if ssl:
388 # SSL related kwargs
389 kwargs.update(
390 {
391 "connection_class": SSLConnection,
392 "ssl_ca_certs": ssl_ca_certs,
393 "ssl_ca_data": ssl_ca_data,
394 "ssl_cert_reqs": ssl_cert_reqs,
395 "ssl_include_verify_flags": ssl_include_verify_flags,
396 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
397 "ssl_certfile": ssl_certfile,
398 "ssl_check_hostname": ssl_check_hostname,
399 "ssl_keyfile": ssl_keyfile,
400 "ssl_min_version": ssl_min_version,
401 "ssl_ciphers": ssl_ciphers,
402 }
403 )
404
405 if read_from_replicas or load_balancing_strategy:
406 # Call our on_connect function to configure READONLY mode
407 kwargs["redis_connect_func"] = self.on_connect
408
409 if retry:
410 self.retry = retry
411 else:
412 self.retry = Retry(
413 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
414 retries=cluster_error_retry_attempts,
415 )
416 if retry_on_error:
417 self.retry.update_supported_errors(retry_on_error)
418
419 kwargs["response_callbacks"] = _RedisCallbacks.copy()
420 if kwargs.get("protocol") in ["3", 3]:
421 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
422 else:
423 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
424 self.connection_kwargs = kwargs
425
426 if startup_nodes:
427 passed_nodes = []
428 for node in startup_nodes:
429 passed_nodes.append(
430 ClusterNode(node.host, node.port, **self.connection_kwargs)
431 )
432 startup_nodes = passed_nodes
433 else:
434 startup_nodes = []
435 if host and port:
436 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
437
438 if event_dispatcher is None:
439 self._event_dispatcher = EventDispatcher()
440 else:
441 self._event_dispatcher = event_dispatcher
442
443 self.startup_nodes = startup_nodes
444 self.nodes_manager = NodesManager(
445 startup_nodes,
446 require_full_coverage,
447 kwargs,
448 dynamic_startup_nodes=dynamic_startup_nodes,
449 address_remap=address_remap,
450 event_dispatcher=self._event_dispatcher,
451 )
452 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
453 self.read_from_replicas = read_from_replicas
454 self.load_balancing_strategy = load_balancing_strategy
455 self.reinitialize_steps = reinitialize_steps
456 self.reinitialize_counter = 0
457
458 # For backward compatibility, mapping from existing policies to new one
459 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
460 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
461 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
462 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
463 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
464 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
465 SLOT_ID: RequestPolicy.DEFAULT_KEYED,
466 }
467
468 self._policies_callback_mapping: dict[
469 Union[RequestPolicy, ResponsePolicy], Callable
470 ] = {
471 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
472 self.get_random_primary_or_all_nodes(command_name)
473 ],
474 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
475 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
476 RequestPolicy.ALL_SHARDS: self.get_primaries,
477 RequestPolicy.ALL_NODES: self.get_nodes,
478 RequestPolicy.ALL_REPLICAS: self.get_replicas,
479 RequestPolicy.SPECIAL: self.get_special_nodes,
480 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
481 ResponsePolicy.DEFAULT_KEYED: lambda res: res,
482 }
483
484 self._policy_resolver = policy_resolver
485 self.commands_parser = AsyncCommandsParser()
486 self._aggregate_nodes = None
487 self.node_flags = self.__class__.NODE_FLAGS.copy()
488 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
489 self.response_callbacks = kwargs["response_callbacks"]
490 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
491 self.result_callbacks["CLUSTER SLOTS"] = (
492 lambda cmd, res, **kwargs: parse_cluster_slots(
493 list(res.values())[0], **kwargs
494 )
495 )
496
497 self._initialize = True
498 self._lock: Optional[asyncio.Lock] = None
499
500 # When used as an async context manager, we need to increment and decrement
501 # a usage counter so that we can close the connection pool when no one is
502 # using the client.
503 self._usage_counter = 0
504 self._usage_lock = asyncio.Lock()
505
    async def initialize(self) -> "RedisCluster":
        """Get all nodes from startup nodes & creates connections if not initialized."""
        if self._initialize:
            # Lazily create the lock so the client can be constructed outside
            # of a running event loop.
            if not self._lock:
                self._lock = asyncio.Lock()
            async with self._lock:
                # Double-checked locking: another task may have completed
                # initialization while we were waiting on the lock.
                if self._initialize:
                    try:
                        await self.nodes_manager.initialize()
                        await self.commands_parser.initialize(
                            self.nodes_manager.default_node
                        )
                        self._initialize = False
                    except BaseException:
                        # Roll back any partially-created connections,
                        # including the ones held by the startup nodes.
                        await self.nodes_manager.aclose()
                        await self.nodes_manager.aclose("startup_nodes")
                        raise
        return self
524
    async def aclose(self) -> None:
        """Close all connections & client if initialized."""
        # Mirror image of initialize(): only tear down when currently
        # initialized, guarded by the same lazily-created lock.
        if not self._initialize:
            if not self._lock:
                self._lock = asyncio.Lock()
            async with self._lock:
                # Re-check under the lock in case another task closed first.
                if not self._initialize:
                    # Flip the flag first so concurrent callers short-circuit.
                    self._initialize = True
                    await self.nodes_manager.aclose()
                    await self.nodes_manager.aclose("startup_nodes")
535
536 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
537 async def close(self) -> None:
538 """alias for aclose() for backwards compatibility"""
539 await self.aclose()
540
541 async def __aenter__(self) -> "RedisCluster":
542 """
543 Async context manager entry. Increments a usage counter so that the
544 connection pool is only closed (via aclose()) when no context is using
545 the client.
546 """
547 await self._increment_usage()
548 try:
549 # Initialize the client (i.e. establish connection, etc.)
550 return await self.initialize()
551 except Exception:
552 # If initialization fails, decrement the counter to keep it in sync
553 await self._decrement_usage()
554 raise
555
556 async def _increment_usage(self) -> int:
557 """
558 Helper coroutine to increment the usage counter while holding the lock.
559 Returns the new value of the usage counter.
560 """
561 async with self._usage_lock:
562 self._usage_counter += 1
563 return self._usage_counter
564
565 async def _decrement_usage(self) -> int:
566 """
567 Helper coroutine to decrement the usage counter while holding the lock.
568 Returns the new value of the usage counter.
569 """
570 async with self._usage_lock:
571 self._usage_counter -= 1
572 return self._usage_counter
573
574 async def __aexit__(self, exc_type, exc_value, traceback):
575 """
576 Async context manager exit. Decrements a usage counter. If this is the
577 last exit (counter becomes zero), the client closes its connection pool.
578 """
579 current_usage = await asyncio.shield(self._decrement_usage())
580 if current_usage == 0:
581 # This was the last active context, so disconnect the pool.
582 await asyncio.shield(self.aclose())
583
584 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
585 return self.initialize().__await__()
586
    _DEL_MESSAGE = "Unclosed RedisCluster client"

    def __del__(
        self,
        # Bound at definition time so they remain reachable during interpreter
        # shutdown, when module globals may already have been cleared.
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn only if the client was initialized but never closed.
        if hasattr(self, "_initialize") and not self._initialize:
            _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop to report to.
                pass
601
602 async def on_connect(self, connection: Connection) -> None:
603 await connection.on_connect()
604
605 # Sending READONLY command to server to configure connection as
606 # readonly. Since each cluster node may change its server type due
607 # to a failover, we should establish a READONLY connection
608 # regardless of the server type. If this is a primary connection,
609 # READONLY would not affect executing write commands.
610 await connection.send_command("READONLY")
611 if str_if_bytes(await connection.read_response()) != "OK":
612 raise ConnectionError("READONLY command failed")
613
614 def get_nodes(self) -> List["ClusterNode"]:
615 """Get all nodes of the cluster."""
616 return list(self.nodes_manager.nodes_cache.values())
617
618 def get_primaries(self) -> List["ClusterNode"]:
619 """Get the primary nodes of the cluster."""
620 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
621
622 def get_replicas(self) -> List["ClusterNode"]:
623 """Get the replica nodes of the cluster."""
624 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
625
626 def get_random_node(self) -> "ClusterNode":
627 """Get a random node of the cluster."""
628 return random.choice(list(self.nodes_manager.nodes_cache.values()))
629
630 def get_default_node(self) -> "ClusterNode":
631 """Get the default node of the client."""
632 return self.nodes_manager.default_node
633
634 def set_default_node(self, node: "ClusterNode") -> None:
635 """
636 Set the default node of the client.
637
638 :raises DataError: if None is passed or node does not exist in cluster.
639 """
640 if not node or not self.get_node(node_name=node.name):
641 raise DataError("The requested node does not exist in the cluster.")
642
643 self.nodes_manager.default_node = node
644
645 def get_node(
646 self,
647 host: Optional[str] = None,
648 port: Optional[int] = None,
649 node_name: Optional[str] = None,
650 ) -> Optional["ClusterNode"]:
651 """Get node by (host, port) or node_name."""
652 return self.nodes_manager.get_node(host, port, node_name)
653
654 def get_node_from_key(
655 self, key: str, replica: bool = False
656 ) -> Optional["ClusterNode"]:
657 """
658 Get the cluster node corresponding to the provided key.
659
660 :param key:
661 :param replica:
662 | Indicates if a replica should be returned
663 |
664 None will returned if no replica holds this key
665
666 :raises SlotNotCoveredError: if the key is not covered by any slot.
667 """
668 slot = self.keyslot(key)
669 slot_cache = self.nodes_manager.slots_cache.get(slot)
670 if not slot_cache:
671 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
672
673 if replica:
674 if len(self.nodes_manager.slots_cache[slot]) < 2:
675 return None
676 node_idx = 1
677 else:
678 node_idx = 0
679
680 return slot_cache[node_idx]
681
682 def get_random_primary_or_all_nodes(self, command_name):
683 """
684 Returns random primary or all nodes depends on READONLY mode.
685 """
686 if self.read_from_replicas and command_name in READ_COMMANDS:
687 return self.get_random_node()
688
689 return self.get_random_primary_node()
690
691 def get_random_primary_node(self) -> "ClusterNode":
692 """
693 Returns a random primary node
694 """
695 return random.choice(self.get_primaries())
696
697 async def get_nodes_from_slot(self, command: str, *args):
698 """
699 Returns a list of nodes that hold the specified keys' slots.
700 """
701 # get the node that holds the key's slot
702 return [
703 self.nodes_manager.get_node_from_slot(
704 await self._determine_slot(command, *args),
705 self.read_from_replicas and command in READ_COMMANDS,
706 self.load_balancing_strategy if command in READ_COMMANDS else None,
707 )
708 ]
709
710 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
711 """
712 Returns a list of nodes for commands with a special policy.
713 """
714 if not self._aggregate_nodes:
715 raise RedisClusterException(
716 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
717 )
718
719 return self._aggregate_nodes
720
721 def keyslot(self, key: EncodableT) -> int:
722 """
723 Find the keyslot for a given key.
724
725 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
726 """
727 return key_slot(self.encoder.encode(key))
728
729 def get_encoder(self) -> Encoder:
730 """Get the encoder object of the client."""
731 return self.encoder
732
733 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
734 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
735 return self.connection_kwargs
736
    def set_retry(self, retry: Retry) -> None:
        """Replace the retry policy used for cluster command execution."""
        self.retry = retry
739
    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback for ``command``, overriding any default."""
        self.response_callbacks[command] = callback
743
    async def _determine_nodes(
        self,
        command: str,
        *args: Any,
        request_policy: RequestPolicy,
        node_flag: Optional[str] = None,
    ) -> List["ClusterNode"]:
        """
        Determine which nodes the command should be executed on.

        An explicit ``node_flag`` (or one pre-registered for the command in
        ``command_flags``) overrides the given ``request_policy``. Returns
        the list of target nodes.
        """
        if not node_flag:
            # get the nodes group for this command if it was predefined
            node_flag = self.command_flags.get(command)

        # A recognized node flag wins over the resolved request policy.
        if node_flag in self._command_flags_mapping:
            request_policy = self._command_flags_mapping[node_flag]

        policy_callback = self._policies_callback_mapping[request_policy]

        # Keyed policies need the full args to compute the slot; keyless ones
        # only need the command name; every other callback takes no args.
        if request_policy == RequestPolicy.DEFAULT_KEYED:
            nodes = await policy_callback(command, *args)
        elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
            nodes = policy_callback(command)
        else:
            nodes = policy_callback()

        # Remember FT.AGGREGATE targets so FT.CURSOR can be routed to them.
        if command.lower() == "ft.aggregate":
            self._aggregate_nodes = nodes

        return nodes
773
    async def _determine_slot(self, command: str, *args: Any) -> int:
        """
        Compute the hash slot targeted by the given command and args.

        :raises RedisClusterException: if the command carries no keys (and is
            not one of the keyless special cases) or its keys map to more
            than one slot.
        """
        if self.command_flags.get(command) == SLOT_ID:
            # The command contains the slot ID
            return int(args[0])

        # Get the keys in the command

        # EVAL and EVALSHA are common enough that it's wasteful to go to the
        # redis server to parse the keys. Besides, there is a bug in redis<7.0
        # where `self._get_command_keys()` fails anyway. So, we special case
        # EVAL/EVALSHA.
        # - issue: https://github.com/redis/redis/issues/9493
        # - fix: https://github.com/redis/redis/pull/9733
        if command.upper() in ("EVAL", "EVALSHA"):
            # command syntax: EVAL "script body" num_keys ...
            if len(args) < 2:
                raise RedisClusterException(
                    f"Invalid args in command: {command, *args}"
                )
            keys = args[2 : 2 + int(args[1])]
            # if there are 0 keys, that means the script can be run on any node
            # so we can just return a random slot
            if not keys:
                return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
        else:
            keys = await self.commands_parser.get_keys(command, *args)
            if not keys:
                # FCALL can call a function with 0 keys, that means the function
                # can be run on any node so we can just return a random slot
                if command.upper() in ("FCALL", "FCALL_RO"):
                    return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
                raise RedisClusterException(
                    "No way to dispatch this command to Redis Cluster. "
                    "Missing key.\nYou can execute the command by specifying "
                    f"target nodes.\nCommand: {args}"
                )

        # single key command
        if len(keys) == 1:
            return self.keyslot(keys[0])

        # multi-key command; we need to make sure all keys are mapped to
        # the same slot
        slots = {self.keyslot(key) for key in keys}
        if len(slots) != 1:
            raise RedisClusterException(
                f"{command} - all keys must map to the same key slot"
            )

        return slots.pop()
824
825 def _is_node_flag(self, target_nodes: Any) -> bool:
826 return isinstance(target_nodes, str) and target_nodes in self.node_flags
827
828 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
829 if isinstance(target_nodes, list):
830 nodes = target_nodes
831 elif isinstance(target_nodes, ClusterNode):
832 # Supports passing a single ClusterNode as a variable
833 nodes = [target_nodes]
834 elif isinstance(target_nodes, dict):
835 # Supports dictionaries of the format {node_name: node}.
836 # It enables to execute commands with multi nodes as follows:
837 # rc.cluster_save_config(rc.get_primaries())
838 nodes = list(target_nodes.values())
839 else:
840 raise TypeError(
841 "target_nodes type can be one of the following: "
842 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
843 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
844 f"The passed type is {type(target_nodes)}"
845 )
846 return nodes
847
848 async def _record_error_metric(
849 self,
850 error: Exception,
851 connection: Union[Connection, "ClusterNode"],
852 is_internal: bool = True,
853 retry_attempts: Optional[int] = None,
854 ):
855 """
856 Records error count metric directly.
857 Accepts either a Connection or ClusterNode object.
858 """
859 await record_error_count(
860 server_address=connection.host,
861 server_port=connection.port,
862 network_peer_address=connection.host,
863 network_peer_port=connection.port,
864 error_type=error,
865 retry_attempts=retry_attempts if retry_attempts is not None else 0,
866 is_internal=is_internal,
867 )
868
869 async def _record_command_metric(
870 self,
871 command_name: str,
872 duration_seconds: float,
873 connection: Union[Connection, "ClusterNode"],
874 error: Optional[Exception] = None,
875 ):
876 """
877 Records operation duration metric directly.
878 Accepts either a Connection or ClusterNode object.
879 """
880 # Connection has db attribute, ClusterNode has connection_kwargs
881 if hasattr(connection, "db"):
882 db = connection.db
883 else:
884 db = connection.connection_kwargs.get("db", 0)
885 await record_operation_duration(
886 command_name=command_name,
887 duration_seconds=duration_seconds,
888 server_address=connection.host,
889 server_port=connection.port,
890 db_namespace=str(db) if db is not None else None,
891 error=error,
892 )
893
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            # Caller pinned the command to explicit nodes: disable retries so
            # the command is never silently re-routed to other nodes.
            retry_attempts = 0

        # Resolve request/response routing policies for this command; may be
        # empty for commands the resolver does not know about.
        command_policies = await self._policy_resolver.resolve(args[0].lower())

        if not command_policies and not target_nodes_specified:
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    command_policies = CommandPolicies()
                else:
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                # Map a legacy node-flag (e.g. "all nodes") onto the newer
                # request-policy mechanism when a mapping exists.
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        failure_count = 0

        # Start timing for observability
        start_time = time.monotonic()

        for _ in range(execute_attempts):
            if self._initialize:
                # Topology was flagged stale (e.g. after a connection error):
                # rebuild the nodes/slots caches before executing.
                await self.initialize()
                if (
                    len(target_nodes) == 1
                    and target_nodes[0] == self.get_default_node()
                ):
                    # Replace the default cluster node
                    self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan the command out to all target nodes concurrently and
                    # aggregate the per-node results by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots caches should be reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    failure_count += 1

                    if hasattr(e, "connection"):
                        await self._record_command_metric(
                            command_name=command,
                            duration_seconds=time.monotonic() - start_time,
                            connection=e.connection,
                            error=e,
                        )
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                        )
                    continue
                else:
                    # raise the exception
                    if hasattr(e, "connection"):
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                            is_internal=False,
                        )
                    raise e
1036
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Execute a command against a single node, following cluster
        redirections (ASK/MOVED) for up to ``RedisClusterRequestTTL``
        attempts, and recording a duration metric for every outcome.

        Exceptions that escape this method get a ``connection`` attribute
        attached so the caller can record node-level error metrics.

        :raises ClusterError: when the redirection TTL is exhausted
        """
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL
        command = args[0]
        start_time = time.monotonic()

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # Previous attempt got an ASK redirect: send ASKING on the
                    # redirected node before retrying the command there.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                response = await target_node.execute_command(*args, **kwargs)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                )
                return response
            except BusyLoadingError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MaxConnectionsError as e:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ConnectionError, TimeoutError) as e:
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ClusterDownError, SlotNotCoveredError) as e:
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    self.nodes_manager.move_slot(e)
                    moved = True
                # MOVED is handled locally (no raise): record and loop again.
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except AskError as e:
                # ASK is a one-shot redirect: remember the target for the
                # next loop iteration and retry there.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except TryAgainError as e:
                # Back off slightly once half of the TTL budget is spent.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except ResponseError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except Exception as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise

        # Redirection budget exhausted without a definitive response.
        e = ClusterError("TTL exhausted.")
        e.connection = target_node
        await self._record_command_metric(
            command_name=command,
            duration_seconds=time.monotonic() - start_time,
            connection=target_node,
            error=e,
        )
        raise e
1222
1223 def pipeline(
1224 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1225 ) -> "ClusterPipeline":
1226 """
1227 Create & return a new :class:`~.ClusterPipeline` object.
1228
1229 Cluster implementation of pipeline does not support transaction or shard_hint.
1230
1231 :raises RedisClusterException: if transaction or shard_hint are truthy values
1232 """
1233 if shard_hint:
1234 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1235
1236 return ClusterPipeline(self, transaction)
1237
1238 def pubsub(
1239 self,
1240 node: Optional["ClusterNode"] = None,
1241 host: Optional[str] = None,
1242 port: Optional[int] = None,
1243 **kwargs: Any,
1244 ) -> "ClusterPubSub":
1245 """
1246 Create and return a ClusterPubSub instance.
1247
1248 Allows passing a ClusterNode, or host&port, to get a pubsub instance
1249 connected to the specified node
1250
1251 :param node: ClusterNode to connect to
1252 :param host: Host of the node to connect to
1253 :param port: Port of the node to connect to
1254 :param kwargs: Additional keyword arguments
1255 :return: ClusterPubSub instance
1256 """
1257 return ClusterPubSub(self, node=node, host=host, port=port, **kwargs)
1258
1259 def keyspace_notifications(
1260 self,
1261 key_prefix: Union[str, bytes, None] = None,
1262 ignore_subscribe_messages: bool = True,
1263 ) -> "AsyncClusterKeyspaceNotifications":
1264 """
1265 Return an
1266 :class:`~redis.asyncio.keyspace_notifications.AsyncClusterKeyspaceNotifications`
1267 object for subscribing to keyspace and keyevent notifications across
1268 all primary nodes in the cluster.
1269
1270 Note: Keyspace notifications must be enabled on all Redis cluster nodes
1271 via the ``notify-keyspace-events`` configuration option.
1272
1273 Args:
1274 key_prefix: Optional prefix to filter and strip from keys in
1275 notifications.
1276 ignore_subscribe_messages: If True, subscribe/unsubscribe
1277 confirmations are not returned by
1278 get_message/listen.
1279 """
1280 from redis.asyncio.keyspace_notifications import (
1281 AsyncClusterKeyspaceNotifications,
1282 )
1283
1284 return AsyncClusterKeyspaceNotifications(
1285 self,
1286 key_prefix=key_prefix,
1287 ignore_subscribe_messages=ignore_subscribe_messages,
1288 )
1289
1290 def lock(
1291 self,
1292 name: KeyT,
1293 timeout: Optional[float] = None,
1294 sleep: float = 0.1,
1295 blocking: bool = True,
1296 blocking_timeout: Optional[float] = None,
1297 lock_class: Optional[Type[Lock]] = None,
1298 thread_local: bool = True,
1299 raise_on_release_error: bool = True,
1300 ) -> Lock:
1301 """
1302 Return a new Lock object using key ``name`` that mimics
1303 the behavior of threading.Lock.
1304
1305 If specified, ``timeout`` indicates a maximum life for the lock.
1306 By default, it will remain locked until release() is called.
1307
1308 ``sleep`` indicates the amount of time to sleep per loop iteration
1309 when the lock is in blocking mode and another client is currently
1310 holding the lock.
1311
1312 ``blocking`` indicates whether calling ``acquire`` should block until
1313 the lock has been acquired or to fail immediately, causing ``acquire``
1314 to return False and the lock not being acquired. Defaults to True.
1315 Note this value can be overridden by passing a ``blocking``
1316 argument to ``acquire``.
1317
1318 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1319 spend trying to acquire the lock. A value of ``None`` indicates
1320 continue trying forever. ``blocking_timeout`` can be specified as a
1321 float or integer, both representing the number of seconds to wait.
1322
1323 ``lock_class`` forces the specified lock implementation. Note that as
1324 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1325 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1326 you have created your own custom lock class.
1327
1328 ``thread_local`` indicates whether the lock token is placed in
1329 thread-local storage. By default, the token is placed in thread local
1330 storage so that a thread only sees its token, not a token set by
1331 another thread. Consider the following timeline:
1332
1333 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1334 thread-1 sets the token to "abc"
1335 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1336 Lock instance.
1337 time: 5, thread-1 has not yet completed. redis expires the lock
1338 key.
1339 time: 5, thread-2 acquired `my-lock` now that it's available.
1340 thread-2 sets the token to "xyz"
1341 time: 6, thread-1 finishes its work and calls release(). if the
1342 token is *not* stored in thread local storage, then
1343 thread-1 would see the token value as "xyz" and would be
1344 able to successfully release the thread-2's lock.
1345
1346 ``raise_on_release_error`` indicates whether to raise an exception when
1347 the lock is no longer owned when exiting the context manager. By default,
1348 this is True, meaning an exception will be raised. If False, the warning
1349 will be logged and the exception will be suppressed.
1350
1351 In some use cases it's necessary to disable thread local storage. For
1352 example, if you have code where one thread acquires a lock and passes
1353 that lock instance to a worker thread to release later. If thread
1354 local storage isn't disabled in this case, the worker thread won't see
1355 the token set by the thread that acquired the lock. Our assumption
1356 is that these cases aren't common and as such default to using
1357 thread local storage."""
1358 if lock_class is None:
1359 lock_class = Lock
1360 return lock_class(
1361 self,
1362 name,
1363 timeout=timeout,
1364 sleep=sleep,
1365 blocking=blocking,
1366 blocking_timeout=blocking_timeout,
1367 thread_local=thread_local,
1368 raise_on_release_error=raise_on_release_error,
1369 )
1370
1371 async def transaction(
1372 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1373 ):
1374 """
1375 Convenience method for executing the callable `func` as a transaction
1376 while watching all keys specified in `watches`. The 'func' callable
1377 should expect a single argument which is a Pipeline object.
1378 """
1379 shard_hint = kwargs.pop("shard_hint", None)
1380 value_from_callable = kwargs.pop("value_from_callable", False)
1381 watch_delay = kwargs.pop("watch_delay", None)
1382 async with self.pipeline(True, shard_hint) as pipe:
1383 while True:
1384 try:
1385 if watches:
1386 await pipe.watch(*watches)
1387 func_value = await func(pipe)
1388 exec_value = await pipe.execute()
1389 return func_value if value_from_callable else exec_value
1390 except WatchError:
1391 if watch_delay is not None and watch_delay > 0:
1392 time.sleep(watch_delay)
1393 continue
1394
1395
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # Resolve "localhost" to an IP so the node name matches the
        # addresses the cluster reports.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # _connections holds every connection ever created for this node;
        # _free holds the subset that is idle and available for reuse.
        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are identified solely by their "host:port" name.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn once if the node is garbage-collected while any of its
        # connections is still open (mirrors asyncio's unclosed-resource
        # warnings); defaults bind the globals so they survive interpreter
        # shutdown.
        for connection in self._connections:
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    pass
                break

    async def disconnect(self) -> None:
        """
        Disconnect every connection of this node concurrently; the first
        exception (if any) is re-raised after all disconnects finished.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Return a free connection, creating a new one while the pool is
        below ``max_connections``.

        :raises MaxConnectionsError: if no connection is free and the pool
            is already full
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, it will be disconnected
        lazily when next acquired via disconnect_if_needed().
        """
        self._free.append(connection)

    def get_encoder(self) -> Encoder:
        """Return an :class:`Encoder` derived from this node's connection kwargs."""
        kwargs = self.connection_kwargs
        encoder_class = kwargs.get("encoder_class", Encoder)
        return encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected when released back to the pool.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read a single reply from ``connection`` and run it through the
        response callback registered for ``command``, if any.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Some commands treat an error reply as a designated empty value.
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command over a pooled connection and return the parsed
        response; the connection is returned to the free queue afterwards.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(connection.pack_command(*args), False)

        # Read response
        try:
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self.disconnect_if_needed(connection)
            # Release connection
            self._free.append(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send all ``commands`` in a single round trip and store each result
        (or the raised exception) on the corresponding command object.

        Returns True if at least one command raised, False otherwise.
        """
        # Acquire connection
        connection = self.acquire_connection()
        # Handle lazy disconnect for connections marked for reconnect
        await self.disconnect_if_needed(connection)

        # Execute command
        await connection.send_packed_command(
            connection.pack_commands(cmd.args for cmd in commands), False
        )

        # Read responses
        ret = False
        for cmd in commands:
            try:
                cmd.result = await self.parse_response(
                    connection, cmd.args[0], **cmd.kwargs
                )
            except Exception as e:
                cmd.result = e
                ret = True

        # Release connection
        await self.disconnect_if_needed(connection)
        self._free.append(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        # Re-authenticate every free connection with the new token, keeping
        # the re-authed connections aside until all are done so none is
        # handed out mid-AUTH.
        tmp_queue = collections.deque()
        while self._free:
            conn = self._free.popleft()
            await conn.retry.call_with_retry(
                lambda: conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                lambda error: self._mock(error),
            )
            await conn.retry.call_with_retry(
                lambda: conn.read_response(), lambda error: self._mock(error)
            )
            tmp_queue.append(conn)

        while tmp_queue:
            conn = tmp_queue.popleft()
            self._free.append(conn)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1668
1669
class NodesManager:
    """
    Tracks cluster topology for the asyncio cluster client: the startup
    nodes, the node and slot caches, and the read load balancer.
    """

    __slots__ = (
        "_dynamic_startup_nodes",
        "_event_dispatcher",
        "_background_tasks",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "_epoch",
        "read_load_balancer",
        "_initialize_lock",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )
1686
1687 def __init__(
1688 self,
1689 startup_nodes: List["ClusterNode"],
1690 require_full_coverage: bool,
1691 connection_kwargs: Dict[str, Any],
1692 dynamic_startup_nodes: bool = True,
1693 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1694 event_dispatcher: Optional[EventDispatcher] = None,
1695 ) -> None:
1696 self.startup_nodes = {node.name: node for node in startup_nodes}
1697 self.require_full_coverage = require_full_coverage
1698 self.connection_kwargs = connection_kwargs
1699 self.address_remap = address_remap
1700
1701 self.default_node: "ClusterNode" = None
1702 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1703 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1704 self._epoch: int = 0
1705 self.read_load_balancer = LoadBalancer()
1706 self._initialize_lock: asyncio.Lock = asyncio.Lock()
1707
1708 self._background_tasks: Set[asyncio.Task] = set()
1709 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1710 if event_dispatcher is None:
1711 self._event_dispatcher = EventDispatcher()
1712 else:
1713 self._event_dispatcher = event_dispatcher
1714
1715 def get_node(
1716 self,
1717 host: Optional[str] = None,
1718 port: Optional[int] = None,
1719 node_name: Optional[str] = None,
1720 ) -> Optional["ClusterNode"]:
1721 if host and port:
1722 # the user passed host and port
1723 if host == "localhost":
1724 host = socket.gethostbyname(host)
1725 return self.nodes_cache.get(get_node_name(host=host, port=port))
1726 elif node_name:
1727 return self.nodes_cache.get(node_name)
1728 else:
1729 raise DataError(
1730 "get_node requires one of the following: 1. node name 2. host and port"
1731 )
1732
    def set_nodes(
        self,
        old: Dict[str, "ClusterNode"],
        new: Dict[str, "ClusterNode"],
        remove_old: bool = False,
    ) -> None:
        """
        Merge the ``new`` node mapping into ``old`` in place.

        When ``remove_old`` is True, nodes present in ``old`` but absent
        from ``new`` are dropped and their connections are torn down in the
        background. Nodes present in both mappings are kept, with their
        connections marked for lazy reconnect; brand-new nodes are added.
        """
        if remove_old:
            for name in list(old.keys()):
                if name not in new:
                    # The node is removed from the cache before its
                    # connections are touched, so it won't be found in
                    # lookups while the disconnect is in flight. Active
                    # connections are marked for reconnect (disconnected
                    # once the current command completes) and free ones are
                    # disconnected in a background task - since the node is
                    # already out of the cache it is safe not to await them.
                    removed_node = old.pop(name)
                    removed_node.update_active_connections_for_reconnect()
                    task = asyncio.create_task(
                        removed_node.disconnect_free_connections()
                    )
                    # Keep a strong reference so the task isn't GC'd before
                    # it finishes; the done-callback removes it again.
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)

        for name, node in new.items():
            if name in old:
                # Preserve the existing node but mark connections for reconnect.
                # This method is sync so we can't call disconnect_free_connections()
                # which is async. Instead, we mark free connections for reconnect
                # and they will be lazily disconnected when acquired via
                # disconnect_if_needed() to avoid race conditions.
                # TODO: Make this method async in the next major release to allow
                # immediate disconnection of free connections.
                existing_node = old[name]
                existing_node.update_active_connections_for_reconnect()
                for conn in existing_node._free:
                    conn.mark_for_reconnect()
                continue
            # New node is detected and should be added to the pool
            old[name] = node
1772
1773 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1774 """
1775 Move a failing node to the end of startup_nodes and nodes_cache so it's
1776 tried last during reinitialization and when selecting the default node.
1777 If the node is not in the respective list, nothing is done.
1778 """
1779 # Move in startup_nodes
1780 if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1781 node = self.startup_nodes.pop(node_name)
1782 self.startup_nodes[node_name] = node # Re-insert at end
1783
1784 # Move in nodes_cache - this affects get_nodes_by_server_type ordering
1785 # which is used to select the default_node during initialize()
1786 if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1787 node = self.nodes_cache.pop(node_name)
1788 self.nodes_cache[node_name] = node # Re-insert at end
1789
    def move_slot(self, e: AskError | MovedError):
        """
        Update the nodes/slots caches after an ASK/MOVED redirection.

        ``e`` carries the redirected host, port and slot id. The redirected
        node becomes the slot's primary; after a failover the old primary is
        demoted to replica and kept at the end of the slot's node list.
        """
        redirected_node = self.get_node(host=e.host, port=e.port)
        if redirected_node:
            # The node already exists
            if redirected_node.server_type != PRIMARY:
                # Update the node's server type
                redirected_node.server_type = PRIMARY
        else:
            # This is a new node, we will add it to the nodes cache
            redirected_node = ClusterNode(
                e.host, e.port, PRIMARY, **self.connection_kwargs
            )
            self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
        slot_nodes = self.slots_cache[e.slot_id]
        if redirected_node not in slot_nodes:
            # The new slot owner is a new server, or a server from a different
            # shard. We need to remove all current nodes from the slot's list
            # (including replications) and add just the new node.
            self.slots_cache[e.slot_id] = [redirected_node]
        elif redirected_node is not slot_nodes[0]:
            # The MOVED error resulted from a failover, and the new slot owner
            # had previously been a replica.
            old_primary = slot_nodes[0]
            # Update the old primary to be a replica and add it to the end of
            # the slot's node list
            old_primary.server_type = REPLICA
            slot_nodes.append(old_primary)
            # Remove the old replica, which is now a primary, from the slot's
            # node list
            slot_nodes.remove(redirected_node)
            # Override the old primary with the new one
            slot_nodes[0] = redirected_node
            if self.default_node == old_primary:
                # Update the default node with the new primary
                self.default_node = redirected_node
        # else: circular MOVED to current primary -> no-op
1826
1827 def get_node_from_slot(
1828 self,
1829 slot: int,
1830 read_from_replicas: bool = False,
1831 load_balancing_strategy=None,
1832 ) -> "ClusterNode":
1833 if read_from_replicas is True and load_balancing_strategy is None:
1834 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1835
1836 try:
1837 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1838 # get the server index using the strategy defined in load_balancing_strategy
1839 primary_name = self.slots_cache[slot][0].name
1840 node_idx = self.read_load_balancer.get_server_index(
1841 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1842 )
1843 return self.slots_cache[slot][node_idx]
1844 return self.slots_cache[slot][0]
1845 except (IndexError, TypeError):
1846 raise SlotNotCoveredError(
1847 f'Slot "{slot}" not covered by the cluster. '
1848 f'"require_full_coverage={self.require_full_coverage}"'
1849 )
1850
1851 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1852 return [
1853 node
1854 for node in self.nodes_cache.values()
1855 if node.server_type == server_type
1856 ]
1857
    async def initialize(self) -> None:
        """Discover the cluster topology and rebuild the slot/node caches.

        Queries ``CLUSTER SLOTS`` on each startup node until one answers,
        then populates ``slots_cache`` and ``nodes_cache`` from the reply.
        Raises RedisClusterException when no startup node is reachable, or
        when ``require_full_coverage`` is set and the discovered topology
        does not cover all hash slots.  Serialized by ``_initialize_lock``
        so concurrent callers only reinitialize once per epoch.
        """
        self.read_load_balancer.reset()
        # Build into temporaries first; the real caches are only swapped in
        # at the end, after validation succeeded.
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Snapshot the epoch before waiting on the lock so we can detect a
        # concurrent initialization that finished while we were blocked.
        epoch = self._epoch

        async with self._initialize_lock:
            if self._epoch != epoch:
                # another initialize call has already reinitialized the
                # nodes since we started waiting for the lock;
                # we don't need to do it again.
                return

            # Convert to tuple to prevent RuntimeError if self.startup_nodes
            # is modified during iteration
            for startup_node in tuple(self.startup_nodes.values()):
                try:
                    # Make sure cluster mode is enabled on this node
                    try:
                        self._event_dispatcher.dispatch(
                            AfterAsyncClusterInstantiationEvent(
                                self.nodes_cache,
                                self.connection_kwargs.get("credential_provider", None),
                            )
                        )
                        cluster_slots = await startup_node.execute_command(
                            "CLUSTER SLOTS"
                        )
                    except ResponseError:
                        raise RedisClusterException(
                            "Cluster mode is not enabled on this node"
                        )
                    startup_nodes_reachable = True
                except Exception as e:
                    # Try the next startup node.
                    # The exception is saved and raised only if we have no more nodes.
                    exception = e
                    continue

                # CLUSTER SLOTS command results in the following output:
                # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
                # where each node contains the following list: [IP, port, node_id]
                # Therefore, cluster_slots[0][2][0] will be the IP address of the
                # primary node of the first slot section.
                # If there's only one server in the cluster, its ``host`` is ''
                # Fix it to the host in startup_nodes
                if (
                    len(cluster_slots) == 1
                    and not cluster_slots[0][2][0]
                    and len(self.startup_nodes) == 1
                ):
                    cluster_slots[0][2][0] = startup_node.host

                for slot in cluster_slots:
                    # Normalize node entries (host/port/id) to str.
                    for i in range(2, len(slot)):
                        slot[i] = [str_if_bytes(val) for val in slot[i]]
                    primary_node = slot[2]
                    host = primary_node[0]
                    if host == "":
                        host = startup_node.host
                    port = int(primary_node[1])
                    host, port = self.remap_host_port(host, port)

                    nodes_for_slot = []

                    # Reuse an already-discovered node object if the same
                    # host:port appeared in an earlier slot section.
                    target_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_node:
                        target_node = ClusterNode(
                            host, port, PRIMARY, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_node.name] = target_node
                    nodes_for_slot.append(target_node)

                    replica_nodes = slot[3:]
                    for replica_node in replica_nodes:
                        host = replica_node[0]
                        port = replica_node[1]
                        host, port = self.remap_host_port(host, port)

                        target_replica_node = tmp_nodes_cache.get(
                            get_node_name(host, port)
                        )
                        if not target_replica_node:
                            target_replica_node = ClusterNode(
                                host, port, REPLICA, **self.connection_kwargs
                            )
                        # add this node to the nodes cache
                        tmp_nodes_cache[target_replica_node.name] = target_replica_node
                        nodes_for_slot.append(target_replica_node)

                    for i in range(int(slot[0]), int(slot[1]) + 1):
                        if i not in tmp_slots:
                            tmp_slots[i] = nodes_for_slot
                        else:
                            # Validate that 2 nodes want to use the same slot cache
                            # setup
                            tmp_slot = tmp_slots[i][0]
                            if tmp_slot.name != target_node.name:
                                disagreements.append(
                                    f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                                )

                                if len(disagreements) > 5:
                                    raise RedisClusterException(
                                        f"startup_nodes could not agree on a valid "
                                        f"slots cache: {', '.join(disagreements)}"
                                    )

                # Validate if all slots are covered or if we should try next startup node
                fully_covered = True
                for i in range(REDIS_CLUSTER_HASH_SLOTS):
                    if i not in tmp_slots:
                        fully_covered = False
                        break
                if fully_covered:
                    break

            if not startup_nodes_reachable:
                raise RedisClusterException(
                    f"Redis Cluster cannot be connected. Please provide at least "
                    f"one reachable node: {str(exception)}"
                ) from exception

            # Check if the slots are not fully covered
            if not fully_covered and self.require_full_coverage:
                # Despite the requirement that the slots be covered, there
                # isn't a full coverage
                raise RedisClusterException(
                    f"All slots are not covered after query all startup_nodes. "
                    f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                    f"covered..."
                )

            # Set the tmp variables to the real variables
            self.slots_cache = tmp_slots
            self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

            if self._dynamic_startup_nodes:
                # Populate the startup nodes with all discovered nodes
                self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

            # Set the default node
            self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
            # Bump the epoch so callers blocked on the lock can skip
            # re-running initialization.
            self._epoch += 1
2007
2008 async def aclose(self, attr: str = "nodes_cache") -> None:
2009 self.default_node = None
2010 await asyncio.gather(
2011 *(
2012 asyncio.create_task(node.disconnect())
2013 for node in getattr(self, attr).values()
2014 )
2015 )
2016
2017 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
2018 """
2019 Remap the host and port returned from the cluster to a different
2020 internal value. Useful if the client is not connecting directly
2021 to the cluster.
2022 """
2023 if self.address_remap:
2024 return self.address_remap((host, port))
2025 return host, port
2026
2027
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
    - :class:`~.ClusterDownError`
    - :class:`~.ConnectionError`
    - :class:`~.TimeoutError`

    Redirection errors:
    - :class:`~.TryAgainError`
    - :class:`~.MovedError`
    - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    """

    __slots__ = (
        "cluster_client",
        "_transaction",
        "_execution_strategy",
    )

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # A truthy ``transaction`` selects atomic MULTI/EXEC execution on a
        # single slot; otherwise commands are fanned out per node.
        if self._transaction:
            self._execution_strategy: ExecutionStrategy = TransactionStrategy(self)
        else:
            self._execution_strategy = PipelineStrategy(self)

    @property
    def nodes_manager(self) -> "NodesManager":
        """Get the nodes manager from the cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback on the cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the underlying strategy and return this pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        "Pipeline instances should always evaluate to True on Python 3+"
        return True

    def __len__(self) -> int:
        # Number of commands currently queued on the strategy.
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            # The pipeline is single-shot: always clear the queue afterwards.
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        """Queue one ``command`` invocation per hash-slot grouping of ``keys``."""
        keys_by_slot = self.cluster_client._partition_keys_by_slot(keys)
        for slot_keys in keys_by_slot.values():
            self.execute_command(command, *slot_keys)
        return self

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """Discard the commands queued in the current transactional context."""
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``"""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Unlink the key given by ``names`` via the strategy."""
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per slot for ``mapping`` (non-atomic across slots)."""
        return self._execution_strategy.mset_nonatomic(mapping)
2195
2196
for command in PIPELINE_BLOCKED_COMMANDS:
    # Normalize e.g. "CLUSTER SETSLOT" -> "cluster_setslot".
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        # Override the inherited implementation with one that raises, since
        # these commands are not allowed inside a cluster pipeline.
        setattr(ClusterPipeline, command, block_pipeline_command(command))
2203
2204
class PipelineCommand:
    """One queued pipeline command together with its eventual result."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Index within the pipeline; keeps results ordered and is used
        # when annotating error messages.
        self.position = position
        self.args = args
        self.kwargs = kwargs
        # Populated after execution: either the reply or the exception.
        self.result: Union[Any, Exception] = None
        # Request/response policies resolved during routing.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return f"[{self.position}] {self.args} ({self.kwargs})"
2215
2216
class ExecutionStrategy(ABC):
    """Interface for ClusterPipeline execution backends.

    ``ClusterPipeline`` delegates every public operation to one of these;
    in this module the concrete implementations are ``PipelineStrategy``
    (non-atomic, per-node batching) and ``TransactionStrategy``
    (single-slot MULTI/EXEC).
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the execution strategy.

        See ClusterPipeline.initialize()
        """
        pass

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        See ClusterPipeline.execute_command()
        """
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        See ClusterPipeline.execute()
        """
        pass

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Executes multiple MSET commands according to the provided slot/pairs mapping.

        See ClusterPipeline.mset_nonatomic()
        """
        pass

    @abstractmethod
    async def reset(self):
        """
        Resets current execution strategy.

        See: ClusterPipeline.reset()
        """
        pass

    @abstractmethod
    def multi(self):
        """
        Starts transactional context.

        See: ClusterPipeline.multi()
        """
        pass

    @abstractmethod
    async def watch(self, *names):
        """
        Watch given keys.

        See: ClusterPipeline.watch()
        """
        pass

    @abstractmethod
    async def unwatch(self):
        """
        Unwatches all previously specified keys

        See: ClusterPipeline.unwatch()
        """
        pass

    @abstractmethod
    async def discard(self):
        """
        Discard the commands queued in the current transactional context.

        See: ClusterPipeline.discard()
        """
        pass

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink the key specified by ``names``.

        See: ClusterPipeline.unlink()
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Number of commands currently queued."""
        pass
2315
2316
class AbstractStrategy(ExecutionStrategy):
    """Shared plumbing for pipeline execution strategies: command queueing
    and exception annotation."""

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Make sure the cluster client is initialized, then start from an
        empty command queue."""
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Queue a raw command for later execution; returns the pipeline so
        calls can be chained."""
        position = len(self._command_queue)
        self._command_queue.append(PipelineCommand(position, *args, **kwargs))
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled
        """
        cmd = " ".join(map(safe_str, command))
        # Prepend the failing command's index and (truncated) text while
        # keeping any remaining exception args intact.
        exception.args = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {exception.args[0]}",
        ) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        return len(self._command_queue)
2385
2386
class PipelineStrategy(AbstractStrategy):
    """Non-transactional pipeline execution: queued commands are grouped by
    target node and each node's batch is sent concurrently."""

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot so multi-key sets can span the
        cluster (non-atomically)."""
        encoder = self._pipe.cluster_client.encoder

        # Group key/value pairs by the slot their key hashes to.
        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """Execute the queued commands, retrying the whole batch on
        cluster-level errors; the queue is always cleared afterwards."""
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                    if retry_attempts > 0:
                        # Try again with the new cluster setup. All other errors
                        # should be raised.
                        retry_attempts -= 1
                        await self._pipe.cluster_client.aclose()
                        await asyncio.sleep(0.25)
                    else:
                        # All other errors should be raised.
                        raise e
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """Route each pending command to a single node, run each node's
        batch concurrently, then handle redirections/errors."""
        # Commands still needing execution: never run (result is None) or
        # previously failed.
        # NOTE(review): ``not cmd.result`` also matches falsy successful
        # replies (0, "", []) — confirm this method is only re-entered when
        # results are None or Exception instances.
        todo = [
            cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
        ]

        # node_name -> (node, [commands for that node])
        nodes = {}
        for cmd in todo:
            # NOTE(review): pop() mutates cmd.kwargs, so a later retry of
            # this command no longer carries target_nodes.
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                # Caller supplied explicit target node(s).
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    # No resolver policy: fall back to command flags, then
                    # to slot-based defaults.
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            # Pipelined commands must resolve to exactly one node each.
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Start timing for observability
        start_time = time.monotonic()

        # Run every node's sub-pipeline concurrently; results/exceptions
        # are stored on the PipelineCommand objects themselves.
        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        # Record operation duration for each node
        for node_name, (node, commands) in nodes.items():
            # Find the first error in this node's commands, if any
            node_error = None
            for cmd in commands:
                if isinstance(cmd.result, Exception):
                    node_error = cmd.result
                    break

            db = node.connection_kwargs.get("db", 0)
            await record_operation_duration(
                command_name="PIPELINE",
                duration_seconds=time.monotonic() - start_time,
                server_address=node.host,
                server_port=node.port,
                db_namespace=str(db) if db is not None else None,
                error=node_error,
            )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {result.args}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        # Results in original submission order (positions preserved).
        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        """Not supported: MULTI requires a transactional pipeline."""
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        """Not supported: WATCH requires a transactional pipeline."""
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        """Not supported: UNWATCH requires a transactional pipeline."""
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        """Not supported: DISCARD requires a transactional pipeline."""
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue an UNLINK for a single key."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2612
2613
class TransactionStrategy(AbstractStrategy):
    """Executes a cluster pipeline as a single-slot MULTI/EXEC transaction
    pinned to one node and one connection (to preserve WATCH state)."""

    # Commands that carry no key and therefore have no slot to resolve.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands executed immediately instead of being queued for EXEC.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which the server-side WATCH state is cleared.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    # Redirections that require re-targeting/retrying the transaction.
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    # Errors that make the pinned connection/node unusable.
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )
2625
2626 def __init__(self, pipe: ClusterPipeline) -> None:
2627 super().__init__(pipe)
2628 self._explicit_transaction = False
2629 self._watching = False
2630 self._pipeline_slots: Set[int] = set()
2631 self._transaction_node: Optional[ClusterNode] = None
2632 self._transaction_connection: Optional[Connection] = None
2633 self._executing = False
2634 self._retry = copy(self._pipe.cluster_client.retry)
2635 self._retry.update_supported_errors(
2636 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2637 )
2638
2639 def _get_client_and_connection_for_transaction(
2640 self,
2641 ) -> Tuple[ClusterNode, Connection]:
2642 """
2643 Find a connection for a pipeline transaction.
2644
2645 For running an atomic transaction, watch keys ensure that contents have not been
2646 altered as long as the watch commands for those keys were sent over the same
2647 connection. So once we start watching a key, we fetch a connection to the
2648 node that owns that slot and reuse it.
2649 """
2650 if not self._pipeline_slots:
2651 raise RedisClusterException(
2652 "At least a command with a key is needed to identify a node"
2653 )
2654
2655 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2656 list(self._pipeline_slots)[0], False
2657 )
2658 self._transaction_node = node
2659
2660 if not self._transaction_connection:
2661 connection: Connection = self._transaction_node.acquire_connection()
2662 self._transaction_connection = connection
2663
2664 return self._transaction_node, self._transaction_connection
2665
2666 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2667 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2668 response = None
2669 error = None
2670
2671 def runner():
2672 nonlocal response
2673 nonlocal error
2674 try:
2675 response = asyncio.run(self._execute_command(*args, **kwargs))
2676 except Exception as e:
2677 error = e
2678
2679 thread = threading.Thread(target=runner)
2680 thread.start()
2681 thread.join()
2682
2683 if error:
2684 raise error
2685
2686 return response
2687
2688 async def _execute_command(
2689 self, *args: Union[KeyT, EncodableT], **kwargs: Any
2690 ) -> Any:
2691 if self._pipe.cluster_client._initialize:
2692 await self._pipe.cluster_client.initialize()
2693
2694 slot_number: Optional[int] = None
2695 if args[0] not in self.NO_SLOTS_COMMANDS:
2696 slot_number = await self._pipe.cluster_client._determine_slot(*args)
2697
2698 if (
2699 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2700 ) and not self._explicit_transaction:
2701 if args[0] == "WATCH":
2702 self._validate_watch()
2703
2704 if slot_number is not None:
2705 if self._pipeline_slots and slot_number not in self._pipeline_slots:
2706 raise CrossSlotTransactionError(
2707 "Cannot watch or send commands on different slots"
2708 )
2709
2710 self._pipeline_slots.add(slot_number)
2711 elif args[0] not in self.NO_SLOTS_COMMANDS:
2712 raise RedisClusterException(
2713 f"Cannot identify slot number for command: {args[0]},"
2714 "it cannot be triggered in a transaction"
2715 )
2716
2717 return self._immediate_execute_command(*args, **kwargs)
2718 else:
2719 if slot_number is not None:
2720 self._pipeline_slots.add(slot_number)
2721
2722 return super().execute_command(*args, **kwargs)
2723
2724 def _validate_watch(self):
2725 if self._explicit_transaction:
2726 raise RedisError("Cannot issue a WATCH after a MULTI")
2727
2728 self._watching = True
2729
2730 async def _immediate_execute_command(self, *args, **options):
2731 return await self._retry.call_with_retry(
2732 lambda: self._get_connection_and_send_command(*args, **options),
2733 self._reinitialize_on_error,
2734 with_failure_count=True,
2735 )
2736
2737 async def _get_connection_and_send_command(self, *args, **options):
2738 redis_node, connection = self._get_client_and_connection_for_transaction()
2739 # Only disconnect if not watching - disconnecting would lose WATCH state
2740 if not self._watching:
2741 await redis_node.disconnect_if_needed(connection)
2742
2743 # Start timing for observability
2744 start_time = time.monotonic()
2745
2746 try:
2747 response = await self._send_command_parse_response(
2748 connection, redis_node, args[0], *args, **options
2749 )
2750
2751 await record_operation_duration(
2752 command_name=args[0],
2753 duration_seconds=time.monotonic() - start_time,
2754 server_address=connection.host,
2755 server_port=connection.port,
2756 db_namespace=str(connection.db),
2757 )
2758
2759 return response
2760 except Exception as e:
2761 e.connection = connection
2762 await record_operation_duration(
2763 command_name=args[0],
2764 duration_seconds=time.monotonic() - start_time,
2765 server_address=connection.host,
2766 server_port=connection.port,
2767 db_namespace=str(connection.db),
2768 error=e,
2769 )
2770 raise
2771
2772 async def _send_command_parse_response(
2773 self,
2774 connection: Connection,
2775 redis_node: ClusterNode,
2776 command_name,
2777 *args,
2778 **options,
2779 ):
2780 """
2781 Send a command and parse the response
2782 """
2783
2784 await connection.send_command(*args)
2785 output = await redis_node.parse_response(connection, command_name, **options)
2786
2787 if command_name in self.UNWATCH_COMMANDS:
2788 self._watching = False
2789 return output
2790
2791 async def _reinitialize_on_error(self, error, failure_count):
2792 if hasattr(error, "connection"):
2793 await record_error_count(
2794 server_address=error.connection.host,
2795 server_port=error.connection.port,
2796 network_peer_address=error.connection.host,
2797 network_peer_port=error.connection.port,
2798 error_type=error,
2799 retry_attempts=failure_count,
2800 is_internal=True,
2801 )
2802
2803 if self._watching:
2804 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2805 raise WatchError("Slot rebalancing occurred while watching keys")
2806
2807 if (
2808 type(error) in self.SLOT_REDIRECT_ERRORS
2809 or type(error) in self.CONNECTION_ERRORS
2810 ):
2811 if self._transaction_connection and self._transaction_node:
2812 # Disconnect and release back to pool
2813 await self._transaction_connection.disconnect()
2814 self._transaction_node.release(self._transaction_connection)
2815 self._transaction_connection = None
2816
2817 self._pipe.cluster_client.reinitialize_counter += 1
2818 if (
2819 self._pipe.cluster_client.reinitialize_steps
2820 and self._pipe.cluster_client.reinitialize_counter
2821 % self._pipe.cluster_client.reinitialize_steps
2822 == 0
2823 ):
2824 await self._pipe.cluster_client.nodes_manager.initialize()
2825 self.reinitialize_counter = 0
2826 else:
2827 if isinstance(error, AskError):
2828 self._pipe.cluster_client.nodes_manager.move_slot(error)
2829
2830 self._executing = False
2831
2832 async def _raise_first_error(self, responses, stack, start_time):
2833 """
2834 Raise the first exception on the stack
2835 """
2836 for r, cmd in zip(responses, stack):
2837 if isinstance(r, Exception):
2838 self._annotate_exception(r, cmd.position + 1, cmd.args)
2839
2840 await record_operation_duration(
2841 command_name="TRANSACTION",
2842 duration_seconds=time.monotonic() - start_time,
2843 server_address=self._transaction_connection.host,
2844 server_port=self._transaction_connection.port,
2845 db_namespace=str(self._transaction_connection.db),
2846 error=r,
2847 )
2848
2849 raise r
2850
2851 def mset_nonatomic(
2852 self, mapping: Mapping[AnyKeyT, EncodableT]
2853 ) -> "ClusterPipeline":
2854 raise NotImplementedError("Method is not supported in transactional context.")
2855
2856 async def execute(
2857 self, raise_on_error: bool = True, allow_redirections: bool = True
2858 ) -> List[Any]:
2859 stack = self._command_queue
2860 if not stack and (not self._watching or not self._pipeline_slots):
2861 return []
2862
2863 return await self._execute_transaction_with_retries(stack, raise_on_error)
2864
2865 async def _execute_transaction_with_retries(
2866 self, stack: List["PipelineCommand"], raise_on_error: bool
2867 ):
2868 return await self._retry.call_with_retry(
2869 lambda: self._execute_transaction(stack, raise_on_error),
2870 lambda error, failure_count: self._reinitialize_on_error(
2871 error, failure_count
2872 ),
2873 with_failure_count=True,
2874 )
2875
    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """Send MULTI, the queued commands, and EXEC over one connection.

        :param stack: queued ``PipelineCommand`` objects (MULTI/EXEC are
            added here, not by the caller)
        :param raise_on_error: when True, raise the first command error found
            in the EXEC reply instead of returning it in the result list
        :raises CrossSlotTransactionError: keys span more than one hash slot
        :raises WatchError: EXEC returned nil because a WATCHed key changed
        :raises InvalidPipelineStack: EXEC reply length does not match the
            command queue
        """
        # Cluster transactions are only valid within a single hash slot.
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        # Wrap the queued commands in MULTI ... EXEC before packing.
        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands flagged with EMPTY_RESPONSE are never sent; their canned
        # response is spliced back into the result below.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            # Attach the connection so callers can clean it up on retry.
            cluster_error.connection = connection
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                # Not sent to the server; remember (index, canned response).
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # A queue-time error aborted the transaction; surface the first
            # recorded error rather than the generic abort.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): only EMPTY_RESPONSE entries are (index, error) tuples;
        # ResponseError/redirect entries above were appended bare and would
        # fail to unpack here. That path appears unreachable because a
        # queue-time error makes EXEC abort (handled above) — confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        if raise_on_error or len(errors) > 0:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        # Record the successful transaction's duration for observability.
        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data
2994
    async def reset(self):
        """Reset all transaction state and release any held connection.

        Issues UNWATCH (when watching) before returning the connection to
        its node's pool; on connection errors falls back to a hard
        disconnect, which also clears WATCH state server-side.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection and self._transaction_node:
                    await self._transaction_connection.disconnect()
                    self._transaction_node.release(self._transaction_connection)
                    self._transaction_connection = None

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False
3024
3025 def multi(self):
3026 if self._explicit_transaction:
3027 raise RedisError("Cannot issue nested calls to MULTI")
3028 if self._command_queue:
3029 raise RedisError(
3030 "Commands without an initial WATCH have already been issued"
3031 )
3032 self._explicit_transaction = True
3033
3034 async def watch(self, *names):
3035 if self._explicit_transaction:
3036 raise RedisError("Cannot issue a WATCH after a MULTI")
3037
3038 return await self.execute_command("WATCH", *names)
3039
3040 async def unwatch(self):
3041 if self._watching:
3042 return await self.execute_command("UNWATCH")
3043
3044 return True
3045
    async def discard(self):
        """Discard all queued commands and reset the transaction state."""
        await self.reset()
3048
3049 async def unlink(self, *names):
3050 return self.execute_command("UNLINK", *names)
3051
3052
3053class ClusterPubSub(PubSub):
3054 """
3055 Async cluster implementation for pub/sub.
3056
3057 IMPORTANT: before using ClusterPubSub, read about the known limitations
3058 with pubsub in Cluster mode and learn how to workaround them:
3059 https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations
3060 """
3061
3062 def __init__(
3063 self,
3064 redis_cluster: "RedisCluster",
3065 node: Optional["ClusterNode"] = None,
3066 host: Optional[str] = None,
3067 port: Optional[int] = None,
3068 push_handler_func: Optional[Callable] = None,
3069 event_dispatcher: Optional[EventDispatcher] = None,
3070 **kwargs: Any,
3071 ) -> None:
3072 """
3073 When a pubsub instance is created without specifying a node, a single
3074 node will be transparently chosen for the pubsub connection on the
3075 first command execution. The node will be determined by:
3076 1. Hashing the channel name in the request to find its keyslot
3077 2. Selecting a node that handles the keyslot: If read_from_replicas is
3078 set to true or load_balancing_strategy is set, a replica can be selected.
3079
3080 :param redis_cluster: RedisCluster instance
3081 :param node: ClusterNode to connect to
3082 :param host: Host of the node to connect to
3083 :param port: Port of the node to connect to
3084 :param push_handler_func: Optional push handler function
3085 :param event_dispatcher: Optional event dispatcher
3086 :param kwargs: Additional keyword arguments
3087 """
3088 self.node = None
3089 self.set_pubsub_node(redis_cluster, node, host, port)
3090
3091 # Create connection pool if node is specified
3092 if self.node is not None:
3093 connection_pool = ConnectionPool(
3094 connection_class=self.node.connection_class,
3095 **self.node.connection_kwargs,
3096 )
3097 else:
3098 connection_pool = None
3099
3100 self.cluster = redis_cluster
3101 self.node_pubsub_mapping: Dict[str, PubSub] = {}
3102 self._pubsubs_generator = self._pubsubs_generator()
3103 if event_dispatcher is None:
3104 self._event_dispatcher = EventDispatcher()
3105 else:
3106 self._event_dispatcher = event_dispatcher
3107 super().__init__(
3108 connection_pool=connection_pool,
3109 encoder=redis_cluster.encoder,
3110 push_handler_func=push_handler_func,
3111 event_dispatcher=self._event_dispatcher,
3112 **kwargs,
3113 )
3114
3115 def set_pubsub_node(
3116 self,
3117 cluster: "RedisCluster",
3118 node: Optional["ClusterNode"] = None,
3119 host: Optional[str] = None,
3120 port: Optional[int] = None,
3121 ) -> None:
3122 """
3123 The pubsub node will be set according to the passed node, host and port
3124 When none of the node, host, or port are specified - the node is set
3125 to None and will be determined by the keyslot of the channel in the
3126 first command to be executed.
3127 RedisClusterException will be thrown if the passed node does not exist
3128 in the cluster.
3129 If host is passed without port, or vice versa, a DataError will be
3130 thrown.
3131 """
3132 if node is not None:
3133 # node is passed by the user
3134 self._raise_on_invalid_node(cluster, node, node.host, node.port)
3135 pubsub_node = node
3136 elif host is not None and port is not None:
3137 # host and port passed by the user
3138 node = cluster.get_node(host=host, port=port)
3139 self._raise_on_invalid_node(cluster, node, host, port)
3140 pubsub_node = node
3141 elif host is not None or port is not None:
3142 # only one of host and port is specified
3143 raise DataError("Specify both host and port")
3144 else:
3145 # nothing specified by the user
3146 pubsub_node = None
3147 self.node = pubsub_node
3148
3149 def get_pubsub_node(self) -> Optional["ClusterNode"]:
3150 """
3151 Get the node that is being used as the pubsub connection.
3152
3153 :return: The ClusterNode being used for pubsub, or None if not yet determined
3154 """
3155 return self.node
3156
3157 def _get_node_pubsub(self, node: "ClusterNode") -> PubSub:
3158 """Get or create a PubSub instance for the given node."""
3159 try:
3160 return self.node_pubsub_mapping[node.name]
3161 except KeyError:
3162 # Create a minimal connection pool for this node
3163 connection_pool = ConnectionPool(
3164 connection_class=node.connection_class, **node.connection_kwargs
3165 )
3166
3167 pubsub = PubSub(
3168 connection_pool=connection_pool,
3169 encoder=self.cluster.encoder,
3170 push_handler_func=self.push_handler_func,
3171 event_dispatcher=self._event_dispatcher,
3172 )
3173 self.node_pubsub_mapping[node.name] = pubsub
3174 return pubsub
3175
3176 async def _sharded_message_generator(
3177 self, timeout: float = 0.0
3178 ) -> Optional[Dict[str, Any]]:
3179 """Generate messages from shard channels across all nodes."""
3180 for _ in range(len(self.node_pubsub_mapping)):
3181 pubsub = next(self._pubsubs_generator)
3182 # Don't pass ignore_subscribe_messages here - let get_sharded_message
3183 # handle the filtering after processing subscription state changes
3184 message = await pubsub.get_message(
3185 ignore_subscribe_messages=False, timeout=timeout
3186 )
3187 if message is not None:
3188 return message
3189 return None
3190
3191 def _pubsubs_generator(self) -> Generator[PubSub, None, None]:
3192 """Generator that yields PubSub instances in round-robin fashion."""
3193 while True:
3194 current_nodes = list(self.node_pubsub_mapping.values())
3195 if not current_nodes:
3196 return # Avoid infinite loop when no subscriptions exist
3197 yield from current_nodes
3198
    async def get_sharded_message(
        self,
        ignore_subscribe_messages: bool = False,
        timeout: float = 0.0,
        target_node: Optional["ClusterNode"] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Get a message from shard channels.

        :param ignore_subscribe_messages: Whether to ignore subscribe messages
        :param timeout: Timeout for message retrieval
        :param target_node: Specific node to get message from
        :return: Message dictionary or None
        """
        if target_node:
            # Poll only the requested node's pubsub (if we hold one for it).
            pubsub = self.node_pubsub_mapping.get(target_node.name)
            if pubsub:
                # Don't pass ignore_subscribe_messages here - let get_sharded_message
                # handle the filtering after processing subscription state changes
                message = await pubsub.get_message(
                    ignore_subscribe_messages=False, timeout=timeout
                )
            else:
                message = None
        else:
            # No target node: round-robin across all node pubsubs.
            message = await self._sharded_message_generator(timeout=timeout)

        if message is None:
            return None
        elif str_if_bytes(message["type"]) == "sunsubscribe":
            # Bookkeeping for a confirmed shard unsubscribe: drop the channel
            # and discard the node's pubsub once nothing is subscribed on it.
            if message["channel"] in self.pending_unsubscribe_shard_channels:
                self.pending_unsubscribe_shard_channels.remove(message["channel"])
                self.shard_channels.pop(message["channel"], None)
                node = self.cluster.get_node_from_key(message["channel"])
                if node and node.name in self.node_pubsub_mapping:
                    pubsub = self.node_pubsub_mapping[node.name]
                    if not pubsub.subscribed:
                        self.node_pubsub_mapping.pop(node.name)

        # Only suppress subscribe/unsubscribe messages, not data messages (smessage)
        if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"):
            if self.ignore_subscribe_messages or ignore_subscribe_messages:
                return None
        return message
3243
3244 async def ssubscribe(self, *args: Any, **kwargs: Any) -> None:
3245 """
3246 Subscribe to shard channels.
3247
3248 :param args: Channel names
3249 :param kwargs: Channel names with handlers
3250 """
3251 if args:
3252 args = list_or_args(args[0], args[1:])
3253 s_channels = dict.fromkeys(args)
3254 s_channels.update(kwargs)
3255
3256 for s_channel, handler in s_channels.items():
3257 node = self.cluster.get_node_from_key(s_channel)
3258 if node:
3259 pubsub = self._get_node_pubsub(node)
3260 if handler:
3261 await pubsub.ssubscribe(**{s_channel: handler})
3262 else:
3263 await pubsub.ssubscribe(s_channel)
3264 self.shard_channels.update(pubsub.shard_channels)
3265 self.pending_unsubscribe_shard_channels.difference_update(
3266 self._normalize_keys({s_channel: None})
3267 )
3268
3269 async def sunsubscribe(self, *args: Any) -> None:
3270 """
3271 Unsubscribe from shard channels.
3272
3273 :param args: Channel names to unsubscribe from. If empty, unsubscribe from all.
3274 """
3275 if args:
3276 args = list_or_args(args[0], args[1:])
3277 else:
3278 args = list(self.shard_channels.keys())
3279
3280 for s_channel in args:
3281 node = self.cluster.get_node_from_key(s_channel)
3282 if node and node.name in self.node_pubsub_mapping:
3283 pubsub = self.node_pubsub_mapping[node.name]
3284 await pubsub.sunsubscribe(s_channel)
3285 self.pending_unsubscribe_shard_channels.update(
3286 pubsub.pending_unsubscribe_shard_channels
3287 )
3288
3289 def get_redis_connection(self) -> Optional["AbstractConnection"]:
3290 """
3291 Get the Redis connection of the pubsub connected node.
3292
3293 Returns the pubsub's dedicated connection (acquired from its own
3294 connection pool), not from the ClusterNode's connection pool.
3295 This avoids the connection pool resource leak that would occur
3296 if we called node.acquire_connection() without releasing.
3297 """
3298 # Return the pubsub's own dedicated connection, which is acquired
3299 # from self.connection_pool when executing pubsub commands.
3300 # This is safe because it's the connection dedicated to this pubsub
3301 # instance, not a shared pool connection from the ClusterNode.
3302 return self.connection
3303
3304 async def aclose(self) -> None:
3305 """
3306 Disconnect the pubsub connection.
3307 """
3308 # Close all shard pubsub instances first
3309 for pubsub in self.node_pubsub_mapping.values():
3310 await pubsub.aclose()
3311 # Let parent handle self.connection disconnect under the lock
3312 # (includes disconnect, release to pool, and clearing self.connection)
3313 await super().aclose()
3314
3315 def _raise_on_invalid_node(
3316 self,
3317 redis_cluster: "RedisCluster",
3318 node: Optional["ClusterNode"],
3319 host: Optional[str],
3320 port: Optional[int],
3321 ) -> None:
3322 """
3323 Raise a RedisClusterException if the node is None or doesn't exist in
3324 the cluster.
3325 """
3326 if node is None or redis_cluster.get_node(node_name=node.name) is None:
3327 raise RedisClusterException(
3328 f"Node {host}:{port} doesn't exist in the cluster"
3329 )
3330
    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Execute a command on the appropriate cluster node.

        Taken code from redis-py and tweaked to make it work within a cluster.
        """
        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        # For shard commands, route to appropriate node
        # NOTE(review): assumes args[0] is a str command name; a bytes name
        # would not match the tuple comparison below — confirm callers.
        command = args[0].upper() if args else ""
        if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"):
            if len(args) > 1:
                channel = args[1]
                # Shard channels are slot-bound: send via the slot owner's
                # dedicated pubsub instance.
                node = self.cluster.get_node_from_key(channel)
                if node:
                    pubsub = self._get_node_pubsub(node)
                    return await pubsub.execute_command(*args, **kwargs)

        # For other commands, use the set node or lazily discover one
        if self.connection is None:
            if self.connection_pool is None:
                if len(args) > 1:
                    # Hash the first channel and get one of the nodes holding
                    # this slot
                    channel = args[1]
                    slot = self.cluster.keyslot(channel)
                    node = self.cluster.nodes_manager.get_node_from_slot(
                        slot,
                        self.cluster.read_from_replicas,
                        self.cluster.load_balancing_strategy,
                    )
                else:
                    # Get a random node
                    node = self.cluster.get_random_node()
                self.node = node
                # Dedicated pool for the chosen node; the parent class
                # acquires the actual connection from it.
                self.connection_pool = ConnectionPool(
                    connection_class=node.connection_class,
                    **node.connection_kwargs,
                )

        # Now we have a connection_pool, use parent's execute_command
        return await super().execute_command(*args, **kwargs)