1import asyncio
2import collections
3import logging
4import random
5import socket
6import threading
7import time
8import warnings
9import weakref
10from abc import ABC, abstractmethod
11from collections import defaultdict
12from copy import copy
13from itertools import chain
14from types import MethodType
15from typing import (
16 TYPE_CHECKING,
17 Any,
18 Callable,
19 Coroutine,
20 Deque,
21 Dict,
22 Generator,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Set,
28 Tuple,
29 Type,
30 TypeVar,
31 Union,
32)
33
34if TYPE_CHECKING:
35 from redis.asyncio.keyspace_notifications import (
36 AsyncClusterKeyspaceNotifications,
37 )
38
39from redis._defaults import (
40 DEFAULT_RETRY_BASE,
41 DEFAULT_RETRY_CAP,
42 DEFAULT_RETRY_COUNT,
43 DEFAULT_SOCKET_CONNECT_TIMEOUT,
44 DEFAULT_SOCKET_READ_SIZE,
45 DEFAULT_SOCKET_TIMEOUT,
46)
47from redis._parsers import AsyncCommandsParser, Encoder
48from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
49from redis._parsers.helpers import get_response_callbacks
50from redis.asyncio.client import PubSub, ResponseCallbackT
51from redis.asyncio.connection import (
52 AbstractConnection,
53 Connection,
54 ConnectionPoolInterface,
55 SSLConnection,
56 parse_url,
57)
58from redis.asyncio.lock import Lock
59from redis.asyncio.observability.recorder import (
60 record_error_count,
61 record_operation_duration,
62)
63from redis.asyncio.retry import Retry
64from redis.auth.token import TokenInterface
65from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
66from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
67from redis.cluster import (
68 PIPELINE_BLOCKED_COMMANDS,
69 PRIMARY,
70 REPLICA,
71 SLOT_ID,
72 AbstractRedisCluster,
73 LoadBalancer,
74 LoadBalancingStrategy,
75 block_pipeline_command,
76 get_node_name,
77 parse_cluster_shards,
78 parse_cluster_shards_unified,
79 parse_cluster_shards_with_str_keys,
80 parse_cluster_slots,
81)
82from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
83from redis.commands.helpers import list_or_args, parse_pubsub_subscriptions
84from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
85from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
86from redis.credentials import CredentialProvider
87from redis.driver_info import DriverInfo, resolve_driver_info
88from redis.event import (
89 AfterAsyncClusterInstantiationEvent,
90 AsyncAfterSlotsCacheRefreshEvent,
91 AsyncEventListenerInterface,
92 EventDispatcher,
93)
94from redis.exceptions import (
95 AskError,
96 BusyLoadingError,
97 ClusterDownError,
98 ClusterError,
99 ConnectionError,
100 CrossSlotTransactionError,
101 DataError,
102 ExecAbortError,
103 InvalidPipelineStack,
104 MaxConnectionsError,
105 MovedError,
106 RedisClusterException,
107 RedisError,
108 ResponseError,
109 SlotNotCoveredError,
110 TimeoutError,
111 TryAgainError,
112 WatchError,
113)
114from redis.typing import (
115 AnyKeyT,
116 ChannelT,
117 EncodableT,
118 KeyT,
119 PubSubHandler,
120 Subscription,
121)
122from redis.utils import (
123 SENTINEL,
124 SSL_AVAILABLE,
125 deprecated_args,
126 deprecated_function,
127 safe_str,
128 str_if_bytes,
129 truncate_text,
130)
131
132if SSL_AVAILABLE:
133 from ssl import TLSVersion, VerifyFlags, VerifyMode
134else:
135 TLSVersion = None
136 VerifyMode = None
137 VerifyFlags = None
138
# Logger for this module; named after the module itself.
logger = logging.getLogger(name=__name__)

# Type variable for the ``target_nodes`` argument accepted by many cluster
# commands: a node-flag string, a single node, a list of nodes, or a mapping
# whose values are nodes.
TargetNodesT = TypeVar(
    "TargetNodesT",
    str,
    "ClusterNode",
    List["ClusterNode"],
    Dict[Any, "ClusterNode"],
)
144
145
146class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
147 """
148 Create a new RedisCluster client.
149
150 Pass one of parameters:
151
152 - `host` & `port`
153 - `startup_nodes`
154
155 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
156 | Use ``await`` :meth:`close` to disconnect connections & close client.
157
158 Many commands support the target_nodes kwarg. It can be one of the
159 :attr:`NODE_FLAGS`:
160
161 - :attr:`PRIMARIES`
162 - :attr:`REPLICAS`
163 - :attr:`ALL_NODES`
164 - :attr:`RANDOM`
165 - :attr:`DEFAULT_NODE`
166
167 Note: This client is not thread/process/fork safe.
168
169 :param host:
170 | Can be used to point to a startup node
171 :param port:
172 | Port used if **host** is provided
173 :param startup_nodes:
        | :class:`~.ClusterNode` to be used as a startup node
175 :param require_full_coverage:
176 | When set to ``False``: the client will not require a full coverage of
177 the slots. However, if not all slots are covered, and at least one node
178 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
179 a :class:`~.ClusterDownError` for some key-based commands.
180 | When set to ``True``: all slots must be covered to construct the cluster
181 client. If not all slots are covered, :class:`~.RedisClusterException` will be
182 thrown.
183 | See:
184 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
185 :param read_from_replicas:
186 | @deprecated - please use load_balancing_strategy instead
187 | Enable read from replicas in READONLY mode.
188 When set to true, read commands will be assigned between the primary and
        its replicas in a round-robin manner.
190 The data read from replicas is eventually consistent with the data in primary nodes.
191 :param load_balancing_strategy:
192 | Enable read from replicas in READONLY mode and defines the load balancing
193 strategy that will be used for cluster node selection.
194 The data read from replicas is eventually consistent with the data in primary nodes.
195 :param dynamic_startup_nodes:
196 | Set the RedisCluster's startup nodes to all the discovered nodes.
197 If true (default value), the cluster's discovered nodes will be used to
198 determine the cluster nodes-slots mapping in the next topology refresh.
199 It will remove the initial passed startup nodes if their endpoints aren't
200 listed in the CLUSTER SLOTS output.
201 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
202 specific IP addresses, it is best to set it to false.
203 :param reinitialize_steps:
204 | Specifies the number of MOVED errors that need to occur before reinitializing
205 the whole cluster topology. If a MOVED error occurs and the cluster does not
206 need to be reinitialized on this current error handling, only the MOVED slot
207 will be patched with the redirected node.
208 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
209 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
210 0.
211 :param cluster_error_retry_attempts:
212 | @deprecated - Please configure the 'retry' object instead
213 In case 'retry' object is set - this argument is ignored!
214
215 Number of times to retry before raising an error when :class:`~.TimeoutError`,
216 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
217 or :class:`~.ClusterDownError` are encountered
218 :param retry:
219 | A retry object that defines the retry strategy and the number of
220 retries for the cluster client.
        In the current implementation for the cluster client (starting from redis-py version 6.0.0)
222 the retry object is not yet fully utilized, instead it is used just to determine
223 the number of retries for the cluster client.
224 In the future releases the retry object will be used to handle the cluster client retries!
225 :param max_connections:
226 | Maximum number of connections per node. If there are no free connections & the
227 maximum number of connections are already created, a
228 :class:`~.MaxConnectionsError` is raised.
229 :param socket_keepalive:
230 | If ``True``, TCP keepalive is enabled for TCP socket connections.
231 :param socket_keepalive_options:
232 | Mapping of TCP keepalive socket option constants to values, for
233 example ``{socket.TCP_KEEPIDLE: 30}``. If left unspecified, redis-py
234 uses TCP keepalive defaults when ``socket_keepalive`` is enabled:
235 idle 30 seconds, interval 5 seconds, and 3 probes.
236 Platform-specific options that are not available are skipped.
237 Pass ``None`` or ``{}`` to avoid setting additional TCP keepalive
238 options.
239 :param address_remap:
240 | An optional callable which, when provided with an internal network
241 address of a node, e.g. a `(host, port)` tuple, will return the address
242 where the node is reachable. This can be used to map the addresses at
243 which the nodes _think_ they are, to addresses at which a client may
244 reach them, such as when they sit behind a proxy.
245
246 | Rest of the arguments will be passed to the
247 :class:`~redis.asyncio.connection.Connection` instances when created
248
249 :raises RedisClusterException:
250 if any arguments are invalid or unknown. Eg:
251
252 - `db` != 0 or None
253 - `path` argument for unix socket connection
254 - none of the `host`/`port` & `startup_nodes` were provided
255
256 """
257
258 @classmethod
259 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
260 """
261 Return a Redis client object configured from the given URL.
262
263 For example::
264
265 redis://[[username]:[password]]@localhost:6379/0
266 rediss://[[username]:[password]]@localhost:6379/0
267
268 Three URL schemes are supported:
269
270 - `redis://` creates a TCP socket connection. See more at:
271 <https://www.iana.org/assignments/uri-schemes/prov/redis>
272 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
273 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
274
275 The username, password, hostname, path and all querystring values are passed
276 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
277 with their corresponding characters.
278
279 All querystring options are cast to their appropriate Python types. Boolean
280 arguments can be specified with string values "True"/"False" or "Yes"/"No".
281 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
282 parsed, the querystring arguments and keyword arguments are passed to
283 :class:`~redis.asyncio.connection.Connection` when created.
284 In the case of conflicting arguments, querystring arguments are used.
285 """
286 kwargs.update(parse_url(url))
287 if kwargs.pop("connection_class", None) is SSLConnection:
288 kwargs["ssl"] = True
289 return cls(**kwargs)
290
    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    # Attribute names declared as slots on this class.
    # NOTE(review): __init__ also assigns attributes that are not listed here
    # (e.g. ``startup_nodes``, ``_event_dispatcher``, ``load_balancing_strategy``),
    # so instances presumably get a ``__dict__`` from a base class — confirm
    # before relying on these slots for memory savings.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
310
311 @deprecated_args(
312 args_to_warn=["read_from_replicas"],
313 reason="Please configure the 'load_balancing_strategy' instead",
314 version="5.3.0",
315 )
316 @deprecated_args(
317 args_to_warn=[
318 "cluster_error_retry_attempts",
319 ],
320 reason="Please configure the 'retry' object instead",
321 version="6.0.0",
322 )
323 @deprecated_args(
324 args_to_warn=["lib_name", "lib_version"],
325 reason="Use 'driver_info' parameter instead. "
326 "lib_name and lib_version will be removed in a future version.",
327 )
328 def __init__(
329 self,
330 host: str | None = None,
331 port: str | int = 6379,
332 # Cluster related kwargs
333 startup_nodes: List["ClusterNode"] | None = None,
334 require_full_coverage: bool = True,
335 read_from_replicas: bool = False,
336 load_balancing_strategy: LoadBalancingStrategy | None = None,
337 dynamic_startup_nodes: bool = True,
338 reinitialize_steps: int = 5,
339 cluster_error_retry_attempts: int = DEFAULT_RETRY_COUNT,
340 max_connections: int = 100,
341 retry: Retry | None = None,
342 retry_on_error: List[Type[Exception]] | None = None,
343 # Client related kwargs
344 db: str | int = 0,
345 path: str | None = None,
346 credential_provider: CredentialProvider | None = None,
347 username: str | None = None,
348 password: str | None = None,
349 client_name: str | None = None,
350 lib_name: str | object | None = SENTINEL,
351 lib_version: str | object | None = SENTINEL,
352 driver_info: DriverInfo | object | None = SENTINEL,
353 # Encoding related kwargs
354 encoding: str = "utf-8",
355 encoding_errors: str = "strict",
356 decode_responses: bool = False,
357 # Connection related kwargs
358 health_check_interval: float = 0,
359 socket_timeout: float | None = DEFAULT_SOCKET_TIMEOUT,
360 socket_connect_timeout: float | None = DEFAULT_SOCKET_CONNECT_TIMEOUT,
361 socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
362 socket_keepalive: bool = True,
363 socket_keepalive_options: Mapping[int, int | bytes] | object | None = SENTINEL,
364 # SSL related kwargs
365 ssl: bool = False,
366 ssl_ca_certs: str | None = None,
367 ssl_ca_data: str | None = None,
368 ssl_cert_reqs: "str | VerifyMode" = "required",
369 ssl_include_verify_flags: List["VerifyFlags"] | None = None,
370 ssl_exclude_verify_flags: List["VerifyFlags"] | None = None,
371 ssl_certfile: str | None = None,
372 ssl_check_hostname: bool = True,
373 ssl_keyfile: str | None = None,
374 ssl_min_version: "TLSVersion | None" = None,
375 ssl_ciphers: str | None = None,
376 protocol: int | None = None,
377 legacy_responses: bool = True,
378 address_remap: Callable[[Tuple[str, int]], Tuple[str, int]] | None = None,
379 event_dispatcher: EventDispatcher | None = None,
380 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
381 ) -> None:
382 if db:
383 raise RedisClusterException(
384 "Argument 'db' must be 0 or None in cluster mode"
385 )
386
387 if path:
388 raise RedisClusterException(
389 "Unix domain socket is not supported in cluster mode"
390 )
391
392 if (not host or not port) and not startup_nodes:
393 raise RedisClusterException(
394 "RedisCluster requires at least one node to discover the cluster.\n"
395 "Please provide one of the following or use RedisCluster.from_url:\n"
396 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
397 " - startup_nodes: RedisCluster(startup_nodes=["
398 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
399 )
400
401 computed_driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
402
403 kwargs: Dict[str, Any] = {
404 "max_connections": max_connections,
405 "connection_class": Connection,
406 # Client related kwargs
407 "credential_provider": credential_provider,
408 "username": username,
409 "password": password,
410 "client_name": client_name,
411 "driver_info": computed_driver_info,
412 # Encoding related kwargs
413 "encoding": encoding,
414 "encoding_errors": encoding_errors,
415 "decode_responses": decode_responses,
416 # Connection related kwargs
417 "health_check_interval": health_check_interval,
418 "socket_connect_timeout": socket_connect_timeout,
419 "socket_keepalive": socket_keepalive,
420 "socket_keepalive_options": socket_keepalive_options,
421 "socket_read_size": socket_read_size,
422 "socket_timeout": socket_timeout,
423 "protocol": protocol,
424 "legacy_responses": legacy_responses,
425 }
426
427 if ssl:
428 # SSL related kwargs
429 kwargs.update(
430 {
431 "connection_class": SSLConnection,
432 "ssl_ca_certs": ssl_ca_certs,
433 "ssl_ca_data": ssl_ca_data,
434 "ssl_cert_reqs": ssl_cert_reqs,
435 "ssl_include_verify_flags": ssl_include_verify_flags,
436 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
437 "ssl_certfile": ssl_certfile,
438 "ssl_check_hostname": ssl_check_hostname,
439 "ssl_keyfile": ssl_keyfile,
440 "ssl_min_version": ssl_min_version,
441 "ssl_ciphers": ssl_ciphers,
442 }
443 )
444
445 if read_from_replicas or load_balancing_strategy:
446 # Call our on_connect function to configure READONLY mode
447 kwargs["redis_connect_func"] = self.on_connect
448
449 if retry:
450 self.retry = retry
451 else:
452 self.retry = Retry(
453 backoff=ExponentialWithJitterBackoff(
454 base=DEFAULT_RETRY_BASE, cap=DEFAULT_RETRY_CAP
455 ),
456 retries=cluster_error_retry_attempts,
457 )
458 if retry_on_error:
459 self.retry.update_supported_errors(retry_on_error)
460
461 kwargs["response_callbacks"] = get_response_callbacks(
462 user_protocol=kwargs.get("protocol"),
463 legacy_responses=kwargs.get("legacy_responses", True),
464 )
465 if not kwargs.get("legacy_responses", True):
466 kwargs["response_callbacks"]["CLUSTER SHARDS"] = (
467 parse_cluster_shards_unified
468 )
469 elif kwargs.get("protocol") is None:
470 kwargs["response_callbacks"]["CLUSTER SHARDS"] = (
471 parse_cluster_shards_with_str_keys
472 )
473 else:
474 kwargs["response_callbacks"]["CLUSTER SHARDS"] = parse_cluster_shards
475 self.connection_kwargs = kwargs
476
477 if startup_nodes:
478 passed_nodes = []
479 for node in startup_nodes:
480 passed_nodes.append(
481 ClusterNode(node.host, node.port, **self.connection_kwargs)
482 )
483 startup_nodes = passed_nodes
484 else:
485 startup_nodes = []
486 if host and port:
487 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
488
489 if event_dispatcher is None:
490 self._event_dispatcher = EventDispatcher()
491 else:
492 self._event_dispatcher = event_dispatcher
493
494 self.startup_nodes = startup_nodes
495 self.nodes_manager = NodesManager(
496 startup_nodes,
497 require_full_coverage,
498 kwargs,
499 dynamic_startup_nodes=dynamic_startup_nodes,
500 address_remap=address_remap,
501 event_dispatcher=self._event_dispatcher,
502 )
503 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
504 self.read_from_replicas = read_from_replicas
505 self.load_balancing_strategy = load_balancing_strategy
506 self.reinitialize_steps = reinitialize_steps
507 self.reinitialize_counter = 0
508
509 # For backward compatibility, mapping from existing policies to new one
510 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
511 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
512 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
513 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
514 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
515 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
516 SLOT_ID: RequestPolicy.DEFAULT_KEYED,
517 }
518
519 self._policies_callback_mapping: dict[
520 Union[RequestPolicy, ResponsePolicy], Callable
521 ] = {
522 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
523 self.get_random_primary_or_all_nodes(command_name)
524 ],
525 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
526 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
527 RequestPolicy.ALL_SHARDS: self.get_primaries,
528 RequestPolicy.ALL_NODES: self.get_nodes,
529 RequestPolicy.ALL_REPLICAS: self.get_replicas,
530 RequestPolicy.SPECIAL: self.get_special_nodes,
531 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
532 ResponsePolicy.DEFAULT_KEYED: lambda res: res,
533 }
534
535 self._policy_resolver = policy_resolver
536 self.commands_parser = AsyncCommandsParser()
537 self._aggregate_nodes = None
538 self.node_flags = self.__class__.NODE_FLAGS.copy()
539 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
540 self.response_callbacks = kwargs["response_callbacks"]
541 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
542 self.result_callbacks["CLUSTER SLOTS"] = (
543 lambda cmd, res, **kwargs: parse_cluster_slots(
544 list(res.values())[0], **kwargs
545 )
546 )
547
548 self._initialize = True
549 self._lock: Optional[asyncio.Lock] = None
550
551 # When used as an async context manager, we need to increment and decrement
552 # a usage counter so that we can close the connection pool when no one is
553 # using the client.
554 self._usage_counter = 0
555 self._usage_lock = asyncio.Lock()
556
557 async def initialize(
558 self,
559 additional_startup_nodes_info: Optional[List[Tuple[str, int]]] = None,
560 last_failed_node_name: Optional[str] = None,
561 ) -> "RedisCluster":
562 """Get all nodes from startup nodes & creates connections if not initialized."""
563 if self._initialize:
564 if not self._lock:
565 self._lock = asyncio.Lock()
566 async with self._lock:
567 if self._initialize:
568 try:
569 await self.nodes_manager.initialize(
570 additional_startup_nodes_info=additional_startup_nodes_info,
571 last_failed_node_name=last_failed_node_name,
572 )
573 await self.commands_parser.initialize(
574 self.nodes_manager.default_node
575 )
576 self._initialize = False
577 except BaseException:
578 await self.nodes_manager.aclose()
579 await self.nodes_manager.aclose("startup_nodes")
580 raise
581 return self
582
583 async def aclose(self) -> None:
584 """Close all connections & client if initialized."""
585 if not self._initialize:
586 if not self._lock:
587 self._lock = asyncio.Lock()
588 async with self._lock:
589 if not self._initialize:
590 self._initialize = True
591 await self.nodes_manager.aclose()
592 await self.nodes_manager.aclose("startup_nodes")
593
594 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
595 async def close(self) -> None:
596 """alias for aclose() for backwards compatibility"""
597 await self.aclose()
598
599 async def __aenter__(self) -> "RedisCluster":
600 """
601 Async context manager entry. Increments a usage counter so that the
602 connection pool is only closed (via aclose()) when no context is using
603 the client.
604 """
605 await self._increment_usage()
606 try:
607 # Initialize the client (i.e. establish connection, etc.)
608 return await self.initialize()
609 except Exception:
610 # If initialization fails, decrement the counter to keep it in sync
611 await self._decrement_usage()
612 raise
613
614 async def _increment_usage(self) -> int:
615 """
616 Helper coroutine to increment the usage counter while holding the lock.
617 Returns the new value of the usage counter.
618 """
619 async with self._usage_lock:
620 self._usage_counter += 1
621 return self._usage_counter
622
623 async def _decrement_usage(self) -> int:
624 """
625 Helper coroutine to decrement the usage counter while holding the lock.
626 Returns the new value of the usage counter.
627 """
628 async with self._usage_lock:
629 self._usage_counter -= 1
630 return self._usage_counter
631
632 async def __aexit__(self, exc_type, exc_value, traceback):
633 """
634 Async context manager exit. Decrements a usage counter. If this is the
635 last exit (counter becomes zero), the client closes its connection pool.
636 """
637 current_usage = await asyncio.shield(self._decrement_usage())
638 if current_usage == 0:
639 # This was the last active context, so disconnect the pool.
640 await asyncio.shield(self.aclose())
641
642 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
643 return self.initialize().__await__()
644
645 _DEL_MESSAGE = "Unclosed RedisCluster client"
646
647 def __del__(
648 self,
649 _warn: Any = warnings.warn,
650 _grl: Any = asyncio.get_running_loop,
651 ) -> None:
652 if hasattr(self, "_initialize") and not self._initialize:
653 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
654 try:
655 context = {"client": self, "message": self._DEL_MESSAGE}
656 _grl().call_exception_handler(context)
657 except RuntimeError:
658 pass
659
660 async def on_connect(self, connection: Connection) -> None:
661 await connection.on_connect()
662
663 # Sending READONLY command to server to configure connection as
664 # readonly. Since each cluster node may change its server type due
665 # to a failover, we should establish a READONLY connection
666 # regardless of the server type. If this is a primary connection,
667 # READONLY would not affect executing write commands.
668 await connection.send_command("READONLY")
669 if str_if_bytes(await connection.read_response()) != "OK":
670 raise ConnectionError("READONLY command failed")
671
672 def get_nodes(self) -> List["ClusterNode"]:
673 """Get all nodes of the cluster."""
674 return list(self.nodes_manager.nodes_cache.values())
675
676 def get_primaries(self) -> List["ClusterNode"]:
677 """Get the primary nodes of the cluster."""
678 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
679
680 def get_replicas(self) -> List["ClusterNode"]:
681 """Get the replica nodes of the cluster."""
682 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
683
684 def get_random_node(self) -> "ClusterNode":
685 """Get a random node of the cluster."""
686 return random.choice(list(self.nodes_manager.nodes_cache.values()))
687
688 def get_default_node(self) -> "ClusterNode":
689 """Get the default node of the client."""
690 return self.nodes_manager.default_node
691
692 def set_default_node(self, node: "ClusterNode") -> None:
693 """
694 Set the default node of the client.
695
696 :raises DataError: if None is passed or node does not exist in cluster.
697 """
698 if not node or not self.get_node(node_name=node.name):
699 raise DataError("The requested node does not exist in the cluster.")
700
701 self.nodes_manager.default_node = node
702
703 def get_node(
704 self,
705 host: Optional[str] = None,
706 port: Optional[int] = None,
707 node_name: Optional[str] = None,
708 ) -> Optional["ClusterNode"]:
709 """Get node by (host, port) or node_name."""
710 return self.nodes_manager.get_node(host, port, node_name)
711
712 def get_node_from_key(
713 self, key: str, replica: bool = False
714 ) -> Optional["ClusterNode"]:
715 """
716 Get the cluster node corresponding to the provided key.
717
718 :param key:
719 :param replica:
720 | Indicates if a replica should be returned
721 |
722 None will returned if no replica holds this key
723
724 :raises SlotNotCoveredError: if the key is not covered by any slot.
725 """
726 slot = self.keyslot(key)
727 slot_cache = self.nodes_manager.slots_cache.get(slot)
728 if not slot_cache:
729 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
730
731 if replica:
732 if len(self.nodes_manager.slots_cache[slot]) < 2:
733 return None
734 node_idx = 1
735 else:
736 node_idx = 0
737
738 return slot_cache[node_idx]
739
740 def get_random_primary_or_all_nodes(self, command_name):
741 """
742 Returns random primary or all nodes depends on READONLY mode.
743 """
744 if self.read_from_replicas and command_name in READ_COMMANDS:
745 return self.get_random_node()
746
747 return self.get_random_primary_node()
748
749 def get_random_primary_node(self) -> "ClusterNode":
750 """
751 Returns a random primary node
752 """
753 return random.choice(self.get_primaries())
754
755 async def get_nodes_from_slot(self, command: str, *args):
756 """
757 Returns a list of nodes that hold the specified keys' slots.
758 """
759 # get the node that holds the key's slot
760 return [
761 self.nodes_manager.get_node_from_slot(
762 await self._determine_slot(command, *args),
763 self.read_from_replicas and command in READ_COMMANDS,
764 self.load_balancing_strategy if command in READ_COMMANDS else None,
765 )
766 ]
767
768 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
769 """
770 Returns a list of nodes for commands with a special policy.
771 """
772 if not self._aggregate_nodes:
773 raise RedisClusterException(
774 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
775 )
776
777 return self._aggregate_nodes
778
779 def keyslot(self, key: EncodableT) -> int:
780 """
781 Find the keyslot for a given key.
782
783 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
784 """
785 return key_slot(self.encoder.encode(key))
786
787 def get_encoder(self) -> Encoder:
788 """Get the encoder object of the client."""
789 return self.encoder
790
791 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
792 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
793 return self.connection_kwargs
794
795 def set_retry(self, retry: Retry) -> None:
796 self.retry = retry
797
798 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
799 """Set a custom response callback."""
800 self.response_callbacks[command] = callback
801
802 async def _determine_nodes(
803 self,
804 command: str,
805 *args: Any,
806 request_policy: RequestPolicy,
807 node_flag: Optional[str] = None,
808 ) -> List["ClusterNode"]:
809 # Determine which nodes should be executed the command on.
810 # Returns a list of target nodes.
811 if not node_flag:
812 # get the nodes group for this command if it was predefined
813 node_flag = self.command_flags.get(command)
814
815 if node_flag in self._command_flags_mapping:
816 request_policy = self._command_flags_mapping[node_flag]
817
818 policy_callback = self._policies_callback_mapping[request_policy]
819
820 if request_policy == RequestPolicy.DEFAULT_KEYED:
821 nodes = await policy_callback(command, *args)
822 elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
823 nodes = policy_callback(command)
824 else:
825 nodes = policy_callback()
826
827 if command.lower() == "ft.aggregate":
828 self._aggregate_nodes = nodes
829
830 return nodes
831
832 async def _determine_slot(self, command: str, *args: Any) -> int:
833 if self.command_flags.get(command) == SLOT_ID:
834 # The command contains the slot ID
835 return int(args[0])
836
837 # Get the keys in the command
838
839 # EVAL and EVALSHA are common enough that it's wasteful to go to the
840 # redis server to parse the keys. Besides, there is a bug in redis<7.0
841 # where `self._get_command_keys()` fails anyway. So, we special case
842 # EVAL/EVALSHA.
843 # - issue: https://github.com/redis/redis/issues/9493
844 # - fix: https://github.com/redis/redis/pull/9733
845 if command.upper() in ("EVAL", "EVALSHA"):
846 # command syntax: EVAL "script body" num_keys ...
847 if len(args) < 2:
848 raise RedisClusterException(
849 f"Invalid args in command: {command, *args}"
850 )
851 keys = args[2 : 2 + int(args[1])]
852 # if there are 0 keys, that means the script can be run on any node
853 # so we can just return a random slot
854 if not keys:
855 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
856 else:
857 keys = await self.commands_parser.get_keys(command, *args)
858 if not keys:
859 # FCALL can call a function with 0 keys, that means the function
860 # can be run on any node so we can just return a random slot
861 if command.upper() in ("FCALL", "FCALL_RO"):
862 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
863 raise RedisClusterException(
864 "No way to dispatch this command to Redis Cluster. "
865 "Missing key.\nYou can execute the command by specifying "
866 f"target nodes.\nCommand: {args}"
867 )
868
869 # single key command
870 if len(keys) == 1:
871 return self.keyslot(keys[0])
872
873 # multi-key command; we need to make sure all keys are mapped to
874 # the same slot
875 slots = {self.keyslot(key) for key in keys}
876 if len(slots) != 1:
877 raise RedisClusterException(
878 f"{command} - all keys must map to the same key slot"
879 )
880
881 return slots.pop()
882
883 def _is_node_flag(self, target_nodes: Any) -> bool:
884 return isinstance(target_nodes, str) and target_nodes in self.node_flags
885
886 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
887 if isinstance(target_nodes, list):
888 nodes = target_nodes
889 elif isinstance(target_nodes, ClusterNode):
890 # Supports passing a single ClusterNode as a variable
891 nodes = [target_nodes]
892 elif isinstance(target_nodes, dict):
893 # Supports dictionaries of the format {node_name: node}.
894 # It enables to execute commands with multi nodes as follows:
895 # rc.cluster_save_config(rc.get_primaries())
896 nodes = list(target_nodes.values())
897 else:
898 raise TypeError(
899 "target_nodes type can be one of the following: "
900 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
901 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
902 f"The passed type is {type(target_nodes)}"
903 )
904 return nodes
905
906 async def _record_error_metric(
907 self,
908 error: Exception,
909 connection: Union[Connection, "ClusterNode"],
910 is_internal: bool = True,
911 retry_attempts: Optional[int] = None,
912 ):
913 """
914 Records error count metric directly.
915 Accepts either a Connection or ClusterNode object.
916 """
917 await record_error_count(
918 server_address=connection.host,
919 server_port=connection.port,
920 network_peer_address=connection.host,
921 network_peer_port=connection.port,
922 error_type=error,
923 retry_attempts=retry_attempts if retry_attempts is not None else 0,
924 is_internal=is_internal,
925 )
926
927 async def _record_command_metric(
928 self,
929 command_name: str,
930 duration_seconds: float,
931 connection: Union[Connection, "ClusterNode"],
932 error: Optional[Exception] = None,
933 ):
934 """
935 Records operation duration metric directly.
936 Accepts either a Connection or ClusterNode object.
937 """
938 # Connection has db attribute, ClusterNode has connection_kwargs
939 if hasattr(connection, "db"):
940 db = connection.db
941 else:
942 db = connection.connection_kwargs.get("db", 0)
943 await record_operation_duration(
944 command_name=command_name,
945 duration_seconds=duration_seconds,
946 server_address=connection.host,
947 server_port=connection.port,
948 db_namespace=str(db) if db is not None else None,
949 error=error,
950 )
951
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception. Retries are disabled when
        ``target_nodes`` selects explicit nodes (anything other than a node
        flag). Only exceptions whose exact type is listed in
        ``ERRORS_ALLOW_RETRY`` are retried; the topology is re-initialized
        before the next attempt when a failure has flagged it as stale.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        # A node-flag string is forwarded to _determine_nodes on every attempt
        # (as node_flag below); any other truthy value is an explicit node
        # selection, which turns retries off.
        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        # Policies are looked up by the lower-cased command name.
        command_policies = await self._policy_resolver.resolve(args[0].lower())

        if not command_policies and not target_nodes_specified:
            # No resolved policy: derive one from the static command flags.
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                # Keyed policy unless there is no default node or no slot was
                # determined. _determine_slot may raise RedisClusterException
                # here, outside the retry loop.
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    command_policies = CommandPolicies()
                else:
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        failure_count = 0

        # Start timing for observability
        # (a single start time, so durations include earlier failed attempts)
        start_time = time.monotonic()
        last_failed_node_name = None

        # Every iteration either returns, raises, or (after consuming one
        # retry) continues, so the loop never falls through.
        for _ in range(execute_attempts):
            if self._initialize:
                # A previous failure marked the topology as stale; rebuild it,
                # passing along the node that failed last (if known).
                await self.initialize(last_failed_node_name=last_failed_node_name)
                last_failed_node_name = None
                if (
                    len(target_nodes) == 1
                    and target_nodes[0] == self.get_default_node()
                ):
                    # Replace the default cluster node
                    self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    # The result callback (if any) runs first; its output is
                    # then passed through the response-policy callback.
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently; results are
                    # keyed by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    # NOTE(review): unlike the single-node path, a result
                    # callback here bypasses the response-policy callback —
                    # confirm this asymmetry is intended.
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                # type(e) is compared exactly, so subclasses of the listed
                # error types are not retried.
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache should be reinitialized.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    failure_count += 1
                    # Set by _execute_command on connection/timeout errors.
                    last_failed_node_name = getattr(e, "last_failed_node_name", None)

                    # _execute_command attaches the failing node as e.connection.
                    if hasattr(e, "connection"):
                        await self._record_command_metric(
                            command_name=command,
                            duration_seconds=time.monotonic() - start_time,
                            connection=e.connection,
                            error=e,
                        )
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                        )
                    continue
                else:
                    # raise the exception
                    if hasattr(e, "connection"):
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                            is_internal=False,
                        )
                    raise e
1097
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Run one command starting at ``target_node``, following cluster
        redirections.

        At most ``self.RedisClusterRequestTTL`` attempts are made in this call:

        - ``ASK``: send ``ASKING`` to the redirected node, then retry there.
        - ``MOVED``: patch the slot cache via ``nodes_manager.move_slot`` (or
          call ``aclose()`` every ``reinitialize_steps`` moves), then re-resolve
          the node for the command's slot.
        - ``TRYAGAIN``: retry; once fewer than half of the TTL attempts remain,
          sleep 50 ms before each retry.

        Every other error is re-raised after attaching the node as
        ``e.connection`` (read by ``execute_command`` for metrics). Connection
        and timeout errors also set ``e.last_failed_node_name`` and set
        ``self._initialize`` so the caller re-initializes before retrying.

        Every outcome records an operation-duration metric measured from the
        start of this call.

        :raises ClusterError: if the TTL is exhausted.
        """
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL
        command = args[0]
        start_time = time.monotonic()

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # NOTE(review): get_node may return None if the ASK target
                    # is not cached; the resulting AttributeError would reach
                    # the generic handler below with e.connection = None.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                response = await target_node.execute_command(*args, **kwargs)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                )
                return response
            # NOTE(review): BusyLoadingError and MaxConnectionsError are listed
            # before (ConnectionError, TimeoutError), presumably because they
            # subclass ConnectionError and must skip the node-failure handling
            # below — confirm against redis.exceptions before reordering.
            except BusyLoadingError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MaxConnectionsError as e:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ConnectionError, TimeoutError) as e:
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)
                e.last_failed_node_name = target_node.name

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ClusterDownError, SlotNotCoveredError) as e:
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                # aclose() presumably resets client state so that the caller's
                # next attempt re-initializes; the short sleep gives the
                # cluster time to settle.
                await self.aclose()
                await asyncio.sleep(0.25)
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads
                # (or tasks). To reduce the frequency you can set this
                # variable in the RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    await self.nodes_manager.move_slot(e)
                moved = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except AskError as e:
                # Redirect only the next attempt; the slot cache is not changed.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except TryAgainError as e:
                # Back off only once more than half of the TTL has been used.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except ResponseError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except Exception as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise

        # Only reached when every attempt ended in MOVED/ASK/TRYAGAIN.
        e = ClusterError("TTL exhausted.")
        e.connection = target_node
        await self._record_command_metric(
            command_name=command,
            duration_seconds=time.monotonic() - start_time,
            connection=target_node,
            error=e,
        )
        raise e
1284
1285 def pipeline(
1286 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1287 ) -> "ClusterPipeline":
1288 """
1289 Create & return a new :class:`~.ClusterPipeline` object.
1290
1291 Cluster implementation of pipeline does not support transaction or shard_hint.
1292
1293 :raises RedisClusterException: if transaction or shard_hint are truthy values
1294 """
1295 if shard_hint:
1296 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1297
1298 return ClusterPipeline(self, transaction)
1299
1300 def pubsub(
1301 self,
1302 node: Optional["ClusterNode"] = None,
1303 host: Optional[str] = None,
1304 port: Optional[int] = None,
1305 **kwargs: Any,
1306 ) -> "ClusterPubSub":
1307 """
1308 Create and return a ClusterPubSub instance.
1309
1310 Allows passing a ClusterNode, or host&port, to get a pubsub instance
1311 connected to the specified node
1312
1313 :param node: ClusterNode to connect to
1314 :param host: Host of the node to connect to
1315 :param port: Port of the node to connect to
1316 :param kwargs: Additional keyword arguments
1317 :return: ClusterPubSub instance
1318 """
1319 return ClusterPubSub(self, node=node, host=host, port=port, **kwargs)
1320
1321 def keyspace_notifications(
1322 self,
1323 key_prefix: Union[str, bytes, None] = None,
1324 ignore_subscribe_messages: bool = True,
1325 ) -> "AsyncClusterKeyspaceNotifications":
1326 """
1327 Return an
1328 :class:`~redis.asyncio.keyspace_notifications.AsyncClusterKeyspaceNotifications`
1329 object for subscribing to keyspace and keyevent notifications across
1330 all primary nodes in the cluster.
1331
1332 Note: Keyspace notifications must be enabled on all Redis cluster nodes
1333 via the ``notify-keyspace-events`` configuration option.
1334
1335 Args:
1336 key_prefix: Optional prefix to filter and strip from keys in
1337 notifications.
1338 ignore_subscribe_messages: If True, subscribe/unsubscribe
1339 confirmations are not returned by
1340 get_message/listen.
1341 """
1342 from redis.asyncio.keyspace_notifications import (
1343 AsyncClusterKeyspaceNotifications,
1344 )
1345
1346 return AsyncClusterKeyspaceNotifications(
1347 self,
1348 key_prefix=key_prefix,
1349 ignore_subscribe_messages=ignore_subscribe_messages,
1350 )
1351
1352 def lock(
1353 self,
1354 name: KeyT,
1355 timeout: Optional[float] = None,
1356 sleep: float = 0.1,
1357 blocking: bool = True,
1358 blocking_timeout: Optional[float] = None,
1359 lock_class: Optional[Type[Lock]] = None,
1360 thread_local: bool = True,
1361 raise_on_release_error: bool = True,
1362 ) -> Lock:
1363 """
1364 Return a new Lock object using key ``name`` that mimics
1365 the behavior of threading.Lock.
1366
1367 If specified, ``timeout`` indicates a maximum life for the lock.
1368 By default, it will remain locked until release() is called.
1369
1370 ``sleep`` indicates the amount of time to sleep per loop iteration
1371 when the lock is in blocking mode and another client is currently
1372 holding the lock.
1373
1374 ``blocking`` indicates whether calling ``acquire`` should block until
1375 the lock has been acquired or to fail immediately, causing ``acquire``
1376 to return False and the lock not being acquired. Defaults to True.
1377 Note this value can be overridden by passing a ``blocking``
1378 argument to ``acquire``.
1379
1380 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1381 spend trying to acquire the lock. A value of ``None`` indicates
1382 continue trying forever. ``blocking_timeout`` can be specified as a
1383 float or integer, both representing the number of seconds to wait.
1384
1385 ``lock_class`` forces the specified lock implementation. Note that as
1386 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1387 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1388 you have created your own custom lock class.
1389
1390 ``thread_local`` indicates whether the lock token is placed in
1391 thread-local storage. By default, the token is placed in thread local
1392 storage so that a thread only sees its token, not a token set by
1393 another thread. Consider the following timeline:
1394
1395 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1396 thread-1 sets the token to "abc"
1397 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1398 Lock instance.
1399 time: 5, thread-1 has not yet completed. redis expires the lock
1400 key.
1401 time: 5, thread-2 acquired `my-lock` now that it's available.
1402 thread-2 sets the token to "xyz"
1403 time: 6, thread-1 finishes its work and calls release(). if the
1404 token is *not* stored in thread local storage, then
1405 thread-1 would see the token value as "xyz" and would be
1406 able to successfully release the thread-2's lock.
1407
1408 ``raise_on_release_error`` indicates whether to raise an exception when
1409 the lock is no longer owned when exiting the context manager. By default,
1410 this is True, meaning an exception will be raised. If False, the warning
1411 will be logged and the exception will be suppressed.
1412
1413 In some use cases it's necessary to disable thread local storage. For
1414 example, if you have code where one thread acquires a lock and passes
1415 that lock instance to a worker thread to release later. If thread
1416 local storage isn't disabled in this case, the worker thread won't see
1417 the token set by the thread that acquired the lock. Our assumption
1418 is that these cases aren't common and as such default to using
1419 thread local storage."""
1420 if lock_class is None:
1421 lock_class = Lock
1422 return lock_class(
1423 self,
1424 name,
1425 timeout=timeout,
1426 sleep=sleep,
1427 blocking=blocking,
1428 blocking_timeout=blocking_timeout,
1429 thread_local=thread_local,
1430 raise_on_release_error=raise_on_release_error,
1431 )
1432
1433 async def transaction(
1434 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1435 ):
1436 """
1437 Convenience method for executing the callable `func` as a transaction
1438 while watching all keys specified in `watches`. The 'func' callable
1439 should expect a single argument which is a Pipeline object.
1440 """
1441 shard_hint = kwargs.pop("shard_hint", None)
1442 value_from_callable = kwargs.pop("value_from_callable", False)
1443 watch_delay = kwargs.pop("watch_delay", None)
1444 async with self.pipeline(True, shard_hint) as pipe:
1445 while True:
1446 try:
1447 if watches:
1448 await pipe.watch(*watches)
1449 func_value = await func(pipe)
1450 exec_value = await pipe.execute()
1451 return func_value if value_from_callable else exec_value
1452 except WatchError:
1453 if watch_delay is not None and watch_delay > 0:
1454 time.sleep(watch_delay)
1455 continue
1456
1457
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are created on demand by :meth:`acquire_connection`, up to
    ``max_connections`` of them. Idle connections are kept in a free queue and
    handed out again before any new connection is created.
    """

    __slots__ = (
        "_background_tasks",
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 100,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        """
        :param host: node host; ``"localhost"`` is resolved to an IP address.
        :param port: node port.
        :param server_type: node role label (e.g. ``PRIMARY`` / ``REPLICA``).
        :param max_connections: maximum number of connections this node creates.
        :param connection_class: class instantiated for each new connection.
        :param connection_kwargs: keyword arguments for ``connection_class``.
            A ``response_callbacks`` entry is removed and kept on the node.
        """
        if host == "localhost":
            # Resolve so the node name matches lookups that resolve
            # "localhost" the same way (NodesManager.get_node does this).
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Popped so it is not forwarded to connection_class; parse_response
        # applies these callbacks at the node level instead.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        # Every connection ever created for this node (in use or idle).
        self._connections: List[Connection] = []
        # Idle connections, reused in FIFO order.
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        # Strong references to fire-and-forget tasks so they are not
        # garbage-collected before completion.
        self._background_tasks: Set[asyncio.Task] = set()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are identified by their "host:port" name only.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        """Warn (once) if the node is garbage-collected with open connections."""
        # warn/get_running_loop are bound as defaults, presumably so they stay
        # usable while module globals are torn down at interpreter exit.
        # _connections is unset if __init__ raised before assigning it (e.g.
        # socket.gethostbyname failed); __del__ must not raise in that case.
        for connection in getattr(self, "_connections", ()):
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop.
                    pass
                break

    async def disconnect(self) -> None:
        """
        Disconnect every connection of this node concurrently.

        All disconnects are awaited; afterwards the first Exception returned
        (if any) is re-raised.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Return an idle connection, or create a new one if the node has fewer
        than ``max_connections``.

        Every successful call must be paired with :meth:`release`.

        :raises MaxConnectionsError: if no idle connection exists and the node
            is already at ``max_connections``.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, disconnect it before
        returning it to the free queue (in a background task).
        """
        if connection.should_reconnect():
            task = asyncio.create_task(self._disconnect_and_release(connection))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            return
        self._free.append(connection)

    async def _disconnect_and_release(self, connection: Connection) -> None:
        """
        Disconnect ``connection`` and put it back into the free queue.

        If disconnecting fails, the connection is dropped from the node
        entirely instead of being reused.
        """
        try:
            await connection.disconnect()
        except Exception as exc:
            logger.debug(
                "disconnecting released cluster connection failed: %r",
                exc,
                exc_info=True,
            )
            try:
                self._connections.remove(connection)
            except ValueError:
                pass
            return

        self._free.append(connection)

    def get_encoder(self) -> Encoder:
        """Return an :class:`Encoder` derived from this node's connection kwargs."""
        kwargs = self.connection_kwargs
        encoder_class = kwargs.get("encoder_class", Encoder)
        return encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected after their current operation completes.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        Errors raised while disconnecting are ignored.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and post-process it for ``command``.

        If ``NEVER_DECODE`` is in ``kwargs``, the reply is read without
        decoding. If ``EMPTY_RESPONSE`` is in ``kwargs``, its value is returned
        instead of raising a :class:`ResponseError`. The node's response
        callback for ``command`` (if any) is applied to the reply, receiving
        the remaining kwargs.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command on a pooled connection and return its parsed reply.

        The connection is returned to the pool on every path, including when
        packing or sending the command fails. Otherwise such failures would
        permanently use up one of the ``max_connections`` slots.

        :raises MaxConnectionsError: if the node's pool is exhausted.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)
            await connection.send_packed_command(connection.pack_command(*args))
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            try:
                await self.disconnect_if_needed(connection)
            finally:
                # Release even if the disconnect above raised.
                self.release(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send ``commands`` in a single write on one connection and read their
        replies in order.

        Each reply, or the Exception raised while parsing it, is stored in the
        corresponding ``cmd.result``. The connection is returned to the pool on
        every path.

        :return: True if at least one command's reply raised an Exception.
        :raises MaxConnectionsError: if the node's pool is exhausted.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)

            await connection.send_packed_command(
                connection.pack_commands(cmd.args for cmd in commands)
            )

            ret = False
            for cmd in commands:
                try:
                    cmd.result = await self.parse_response(
                        connection, cmd.args[0], **cmd.kwargs
                    )
                except Exception as e:
                    cmd.result = e
                    ret = True
        finally:
            try:
                await self.disconnect_if_needed(connection)
            finally:
                # Release even if the disconnect above raised.
                self.release(connection)

        return ret

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every idle connection of this node with ``token``.

        Sends ``AUTH <oid> <token value>`` on each connection in the free queue
        and reads the reply. In-use connections are not touched. Connections
        taken out of the queue are always put back, even if AUTH fails.
        """
        tmp_queue = collections.deque()
        try:
            while self._free:
                conn = self._free.popleft()
                # Track it before authenticating so it is restored on failure.
                tmp_queue.append(conn)
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
        finally:
            while tmp_queue:
                self._free.append(tmp_queue.popleft())

    async def _mock(self, error: RedisError):
        """
        No-op failure callback passed to ``Retry.call_with_retry`` by
        :meth:`re_auth_callback`.

        :param error: the error reported by the retry helper (ignored).
        :return: None
        """
        pass
1754
1755
1756class NodesManager:
1757 __slots__ = (
1758 "_dynamic_startup_nodes",
1759 "_event_dispatcher",
1760 "_background_tasks",
1761 "connection_kwargs",
1762 "default_node",
1763 "nodes_cache",
1764 "_epoch",
1765 "read_load_balancer",
1766 "_initialize_lock",
1767 "require_full_coverage",
1768 "slots_cache",
1769 "startup_nodes",
1770 "address_remap",
1771 )
1772
1773 def __init__(
1774 self,
1775 startup_nodes: List["ClusterNode"],
1776 require_full_coverage: bool,
1777 connection_kwargs: Dict[str, Any],
1778 dynamic_startup_nodes: bool = True,
1779 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1780 event_dispatcher: Optional[EventDispatcher] = None,
1781 ) -> None:
1782 self.startup_nodes = {node.name: node for node in startup_nodes}
1783 self.require_full_coverage = require_full_coverage
1784 self.connection_kwargs = connection_kwargs
1785 self.address_remap = address_remap
1786
1787 self.default_node: "ClusterNode" = None
1788 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1789 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1790 self._epoch: int = 0
1791 self.read_load_balancer = LoadBalancer()
1792 self._initialize_lock: asyncio.Lock = asyncio.Lock()
1793
1794 self._background_tasks: Set[asyncio.Task] = set()
1795 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1796 if event_dispatcher is None:
1797 self._event_dispatcher = EventDispatcher()
1798 else:
1799 self._event_dispatcher = event_dispatcher
1800
1801 def get_node(
1802 self,
1803 host: Optional[str] = None,
1804 port: Optional[int] = None,
1805 node_name: Optional[str] = None,
1806 ) -> Optional["ClusterNode"]:
1807 if host and port:
1808 # the user passed host and port
1809 if host == "localhost":
1810 host = socket.gethostbyname(host)
1811 return self.nodes_cache.get(get_node_name(host=host, port=port))
1812 elif node_name:
1813 return self.nodes_cache.get(node_name)
1814 else:
1815 raise DataError(
1816 "get_node requires one of the following: 1. node name 2. host and port"
1817 )
1818
1819 def set_nodes(
1820 self,
1821 old: Dict[str, "ClusterNode"],
1822 new: Dict[str, "ClusterNode"],
1823 remove_old: bool = False,
1824 ) -> None:
1825 if remove_old:
1826 for name in list(old.keys()):
1827 if name not in new:
1828 # Node is removed from cache before disconnect starts,
1829 # so it won't be found in lookups during disconnect
1830 # Mark active connections so in-flight commands can
1831 # finish, then disconnect them when their current
1832 # operation completes. Free connections can be
1833 # disconnected immediately.
1834 removed_node = old.pop(name)
1835 removed_node.update_active_connections_for_reconnect()
1836 task = asyncio.create_task(
1837 removed_node.disconnect_free_connections()
1838 )
1839 self._background_tasks.add(task)
1840 task.add_done_callback(self._background_tasks.discard)
1841
1842 for name, node in new.items():
1843 if name in old:
1844 # Preserve the existing node but mark connections for reconnect.
1845 # This method is sync so we can't call disconnect_free_connections()
1846 # which is async. Instead, we mark free connections for reconnect
1847 # and they will be lazily disconnected when acquired via
1848 # disconnect_if_needed() to avoid race conditions.
1849 # TODO: Make this method async in the next major release to allow
1850 # immediate disconnection of free connections.
1851 existing_node = old[name]
1852 existing_node.server_type = node.server_type
1853 existing_node.update_active_connections_for_reconnect()
1854 for conn in existing_node._free:
1855 conn.mark_for_reconnect()
1856 continue
1857 # New node is detected and should be added to the pool
1858 old[name] = node
1859
1860 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1861 """
1862 Move a failing node to the end of startup_nodes and nodes_cache so it's
1863 tried last during reinitialization and when selecting the default node.
1864 If the node is not in the respective list, nothing is done.
1865 """
1866 # Move in startup_nodes
1867 if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1868 node = self.startup_nodes.pop(node_name)
1869 self.startup_nodes[node_name] = node # Re-insert at end
1870
1871 # Move in nodes_cache - this affects get_nodes_by_server_type ordering
1872 # which is used to select the default_node during initialize()
1873 if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1874 node = self.nodes_cache.pop(node_name)
1875 self.nodes_cache[node_name] = node # Re-insert at end
1876
1877 async def move_slot(self, e: AskError | MovedError):
1878 node_changed = False
1879 redirected_node = self.get_node(host=e.host, port=e.port)
1880 if redirected_node:
1881 # The node already exists
1882 if redirected_node.server_type != PRIMARY:
1883 # Update the node's server type
1884 redirected_node.server_type = PRIMARY
1885 else:
1886 # This is a new node, we will add it to the nodes cache
1887 redirected_node = ClusterNode(
1888 e.host, e.port, PRIMARY, **self.connection_kwargs
1889 )
1890 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
1891 slot_nodes = self.slots_cache[e.slot_id]
1892 if redirected_node not in slot_nodes:
1893 # The new slot owner is a new server, or a server from a different
1894 # shard. We need to remove all current nodes from the slot's list
1895 # (including replications) and add just the new node.
1896 self.slots_cache[e.slot_id] = [redirected_node]
1897 node_changed = True
1898 elif redirected_node is not slot_nodes[0]:
1899 # The MOVED error resulted from a failover, and the new slot owner
1900 # had previously been a replica.
1901 old_primary = slot_nodes[0]
1902 # Update the old primary to be a replica and add it to the end of
1903 # the slot's node list
1904 old_primary.server_type = REPLICA
1905 slot_nodes.append(old_primary)
1906 # Remove the old replica, which is now a primary, from the slot's
1907 # node list
1908 slot_nodes.remove(redirected_node)
1909 # Override the old primary with the new one
1910 slot_nodes[0] = redirected_node
1911 if self.default_node == old_primary:
1912 # Update the default node with the new primary
1913 self.default_node = redirected_node
1914 node_changed = True
1915 # else: circular MOVED to current primary -> no-op
1916 # Dispatch so listeners can run shard-pubsub reconciliation; skipped on
1917 # the no-op branch to avoid needless walks under MOVED storms. A
1918 # listener must not break slots-cache refresh; log and continue so a
1919 # single buggy listener cannot starve the rest.
1920 if node_changed:
1921 try:
1922 await self._event_dispatcher.dispatch_async(
1923 AsyncAfterSlotsCacheRefreshEvent()
1924 )
1925 except Exception as exc:
1926 # Don't shadow the method parameter ``e``: ``except as`` binds
1927 # the listener exception in the function scope and ``del``s
1928 # the name on block exit (PEP 3134), which would also wipe
1929 # out the original AskError/MovedError parameter.
1930 logger.exception(
1931 "listener raised during slots-cache refresh: %s: %s",
1932 type(exc).__name__,
1933 exc,
1934 )
1935
1936 def get_node_from_slot(
1937 self,
1938 slot: int,
1939 read_from_replicas: bool = False,
1940 load_balancing_strategy=None,
1941 ) -> "ClusterNode":
1942 if read_from_replicas is True and load_balancing_strategy is None:
1943 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1944
1945 try:
1946 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1947 # get the server index using the strategy defined in load_balancing_strategy
1948 primary_name = self.slots_cache[slot][0].name
1949 node_idx = self.read_load_balancer.get_server_index(
1950 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1951 )
1952 return self.slots_cache[slot][node_idx]
1953 return self.slots_cache[slot][0]
1954 except (IndexError, TypeError):
1955 raise SlotNotCoveredError(
1956 f'Slot "{slot}" not covered by the cluster. '
1957 f'"require_full_coverage={self.require_full_coverage}"'
1958 )
1959
1960 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1961 return [
1962 node
1963 for node in self.nodes_cache.values()
1964 if node.server_type == server_type
1965 ]
1966
1967 async def initialize(
1968 self,
1969 additional_startup_nodes_info: Optional[List[Tuple[str, int]]] = None,
1970 last_failed_node_name: Optional[str] = None,
1971 ) -> None:
1972 self.read_load_balancer.reset()
1973 tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
1974 tmp_slots: Dict[int, List["ClusterNode"]] = {}
1975 disagreements = []
1976 startup_nodes_reachable = False
1977 fully_covered = False
1978 exception = None
1979 epoch = self._epoch
1980 if additional_startup_nodes_info is None:
1981 additional_startup_nodes_info = []
1982
1983 async with self._initialize_lock:
1984 if self._epoch != epoch:
1985 # another initialize call has already reinitialized the
1986 # nodes since we started waiting for the lock;
1987 # we don't need to do it again.
1988 return
1989
1990 # Copy to a list to prevent RuntimeError if self.startup_nodes
1991 # is modified during iteration, then shuffle the iteration order.
1992 startup_nodes = list(self.startup_nodes.values())
1993 deferred_failed_nodes = []
1994 if last_failed_node_name is not None:
1995 for index, node in enumerate(startup_nodes):
1996 if node.name == last_failed_node_name:
1997 deferred_failed_nodes.append(startup_nodes.pop(index))
1998 break
1999 if len(startup_nodes) > 1:
2000 # Vary which startup node is queried first so clients do not
2001 # all reinitialize through the same node.
2002 random.shuffle(startup_nodes)
2003 additional_startup_nodes = [
2004 ClusterNode(host, port, **self.connection_kwargs)
2005 for host, port in additional_startup_nodes_info
2006 ]
2007 if last_failed_node_name is not None:
2008 for index, node in enumerate(additional_startup_nodes):
2009 if node.name == last_failed_node_name:
2010 if not deferred_failed_nodes:
2011 deferred_failed_nodes.append(node)
2012 additional_startup_nodes.pop(index)
2013 break
2014 for startup_node in chain(
2015 startup_nodes,
2016 additional_startup_nodes,
2017 deferred_failed_nodes,
2018 ):
2019 try:
2020 # Make sure cluster mode is enabled on this node
2021 try:
2022 self._event_dispatcher.dispatch(
2023 AfterAsyncClusterInstantiationEvent(
2024 self.nodes_cache,
2025 self.connection_kwargs.get("credential_provider", None),
2026 )
2027 )
2028 cluster_slots = await startup_node.execute_command(
2029 "CLUSTER SLOTS"
2030 )
2031 except ResponseError:
2032 raise RedisClusterException(
2033 "Cluster mode is not enabled on this node"
2034 )
2035 startup_nodes_reachable = True
2036 except Exception as e:
2037 # Try the next startup node.
2038 # The exception is saved and raised only if we have no more nodes.
2039 exception = e
2040 continue
2041
2042 # CLUSTER SLOTS command results in the following output:
2043 # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
2044 # where each node contains the following list: [IP, port, node_id]
2045 # Therefore, cluster_slots[0][2][0] will be the IP address of the
2046 # primary node of the first slot section.
2047 # If there's only one server in the cluster, its ``host`` is ''
2048 # Fix it to the host in startup_nodes
2049 if (
2050 len(cluster_slots) == 1
2051 and not cluster_slots[0][2][0]
2052 and len(self.startup_nodes) == 1
2053 ):
2054 cluster_slots[0][2][0] = startup_node.host
2055
2056 for slot in cluster_slots:
2057 for i in range(2, len(slot)):
2058 slot[i] = [str_if_bytes(val) for val in slot[i]]
2059 primary_node = slot[2]
2060 host = primary_node[0]
2061 if host == "":
2062 host = startup_node.host
2063 port = int(primary_node[1])
2064 host, port = self.remap_host_port(host, port)
2065
2066 nodes_for_slot = []
2067
2068 target_node = tmp_nodes_cache.get(get_node_name(host, port))
2069 if not target_node:
2070 target_node = ClusterNode(
2071 host, port, PRIMARY, **self.connection_kwargs
2072 )
2073 # add this node to the nodes cache
2074 tmp_nodes_cache[target_node.name] = target_node
2075 nodes_for_slot.append(target_node)
2076
2077 replica_nodes = slot[3:]
2078 for replica_node in replica_nodes:
2079 host = replica_node[0]
2080 port = replica_node[1]
2081 host, port = self.remap_host_port(host, port)
2082
2083 target_replica_node = tmp_nodes_cache.get(
2084 get_node_name(host, port)
2085 )
2086 if not target_replica_node:
2087 target_replica_node = ClusterNode(
2088 host, port, REPLICA, **self.connection_kwargs
2089 )
2090 # add this node to the nodes cache
2091 tmp_nodes_cache[target_replica_node.name] = target_replica_node
2092 nodes_for_slot.append(target_replica_node)
2093
2094 for i in range(int(slot[0]), int(slot[1]) + 1):
2095 if i not in tmp_slots:
2096 tmp_slots[i] = nodes_for_slot
2097 else:
2098 # Validate that 2 nodes want to use the same slot cache
2099 # setup
2100 tmp_slot = tmp_slots[i][0]
2101 if tmp_slot.name != target_node.name:
2102 disagreements.append(
2103 f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
2104 )
2105
2106 if len(disagreements) > 5:
2107 raise RedisClusterException(
2108 f"startup_nodes could not agree on a valid "
2109 f"slots cache: {', '.join(disagreements)}"
2110 )
2111
2112 # Validate if all slots are covered or if we should try next startup node
2113 fully_covered = True
2114 for i in range(REDIS_CLUSTER_HASH_SLOTS):
2115 if i not in tmp_slots:
2116 fully_covered = False
2117 break
2118 if fully_covered:
2119 break
2120
2121 if not startup_nodes_reachable:
2122 raise RedisClusterException(
2123 f"Redis Cluster cannot be connected. Please provide at least "
2124 f"one reachable node: {str(exception)}"
2125 ) from exception
2126
2127 # Check if the slots are not fully covered
2128 if not fully_covered and self.require_full_coverage:
2129 # Despite the requirement that the slots be covered, there
2130 # isn't a full coverage
2131 raise RedisClusterException(
2132 f"All slots are not covered after query all startup_nodes. "
2133 f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
2134 f"covered..."
2135 )
2136
2137 # Set the tmp variables to the real variables
2138 self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
2139 # tmp_slots was built from CLUSTER SLOTS responses and can contain
2140 # newly-created ClusterNode objects for nodes we already know about.
2141 # Rebuild the slots cache with the preserved nodes_cache instances
2142 # so existing per-node connection pools stay in use after refresh.
2143 # Keep the shared node-list-per-slot-range shape from tmp_slots to
2144 # avoid allocating a separate list for every slot.
2145 node_lists_by_id: Dict[int, List["ClusterNode"]] = {}
2146 new_slots_cache: Dict[int, List["ClusterNode"]] = {}
2147 for slot, nodes in tmp_slots.items():
2148 node_list_id = id(nodes)
2149 slot_nodes = node_lists_by_id.get(node_list_id)
2150 if slot_nodes is None:
2151 slot_nodes = [self.nodes_cache[node.name] for node in nodes]
2152 node_lists_by_id[node_list_id] = slot_nodes
2153 new_slots_cache[slot] = slot_nodes
2154 self.slots_cache = new_slots_cache
2155
2156 if self._dynamic_startup_nodes:
2157 # Populate the startup nodes with all discovered nodes
2158 self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
2159
2160 # Set the default node
2161 self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
2162 self._epoch += 1
2163 # Dispatch so listeners (e.g. ClusterPubSub) can reconcile per-node
2164 # state after slot ownership may have changed. A listener must not
2165 # break slots-cache refresh; log and continue so a single buggy
2166 # listener cannot starve the rest.
2167 try:
2168 await self._event_dispatcher.dispatch_async(
2169 AsyncAfterSlotsCacheRefreshEvent()
2170 )
2171 except Exception as e:
2172 logger.exception(
2173 "listener raised during slots-cache refresh: %s: %s",
2174 type(e).__name__,
2175 e,
2176 )
2177
2178 async def aclose(self, attr: str = "nodes_cache") -> None:
2179 self.default_node = None
2180 await asyncio.gather(
2181 *(
2182 asyncio.create_task(node.disconnect())
2183 for node in getattr(self, attr).values()
2184 )
2185 )
2186
2187 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
2188 """
2189 Remap the host and port returned from the cluster to a different
2190 internal value. Useful if the client is not connecting directly
2191 to the cluster.
2192 """
2193 if self.address_remap:
2194 return self.address_remap((host, port))
2195 return host, port
2196
2197
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    :param transaction:
        | If truthy, commands go through :class:`TransactionStrategy`;
          otherwise through :class:`PipelineStrategy`.
    """

    __slots__ = (
        "cluster_client",
        "_transaction",
        "_execution_strategy",
    )

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # All queuing and execution logic lives in the strategy object; the
        # methods below only forward to it.
        strategy_cls = TransactionStrategy if transaction else PipelineStrategy
        self._execution_strategy: ExecutionStrategy = strategy_cls(self)

    @property
    def nodes_manager(self) -> "NodesManager":
        """The nodes manager of the wrapped cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Register ``callback`` for ``command`` on the wrapped cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the execution strategy, then return this pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        # Leaving the context always drops whatever is still queued.
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        # ``await pipe`` is shorthand for ``await pipe.initialize()``.
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        """A pipeline is always truthy, even when nothing is queued (len 0)."""
        return True

    def __len__(self) -> int:
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Hand a raw command to the execution strategy.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run every queued command and return their results.

        Retries follow the client's :attr:`retry` policy before the error is
        raised. The pipeline is reset afterwards, whether or not execution
        succeeded.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            results = await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            await self.reset()
        return results

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        """Queue ``command`` once per hash slot, each with that slot's keys."""
        keys_by_slot = self.cluster_client._partition_keys_by_slot(keys)
        for same_slot_keys in keys_by_slot.values():
            self.execute_command(command, *same_slot_keys)
        return self

    async def reset(self):
        """
        Return the pipeline to its empty state.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Open a transactional block; commands queued after this are executed
        by `execute`. Use after WATCH commands have been issued.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """Discard the transactional block (delegated to the strategy)."""
        await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watch the keys ``names`` (delegated to the strategy)."""
        await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Stop watching all previously watched keys."""
        await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """Unlink the keys ``names`` (delegated to the strategy)."""
        await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot covering ``mapping``."""
        return self._execution_strategy.mset_nonatomic(mapping)
2365
2366
# Each command listed in PIPELINE_BLOCKED_COMMANDS gets its ClusterPipeline
# method replaced with the callable built by block_pipeline_command()
# (presumably one that refuses the call). Method names are the command names
# lower-cased with spaces turned into underscores. mset_nonatomic is exempt
# because ClusterPipeline implements it itself.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
2373
2374
class PipelineCommand:
    """
    One queued pipeline command.

    Holds its original queue ``position``, the raw ``args``/``kwargs``, the
    ``result`` (a reply or an exception, ``None`` until executed) and the
    ``command_policies`` resolved for it at execution time.
    """

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs
        self.position = position
        self.result: Union[Any, Exception] = None
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)
2385
2386
class ExecutionStrategy(ABC):
    """
    Interface a :class:`ClusterPipeline` delegates all of its work to.

    Concrete strategies decide how queued commands are batched and sent
    (plain per-node pipelining or a single-slot transaction).
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Prepare the strategy for use and return the owning pipeline.

        See ClusterPipeline.initialize()
        """

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Accept one raw command.

        See ClusterPipeline.execute_command()
        """

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands and return their results, retrying according
        to the client's :attr:`retry` policy before raising.

        See ClusterPipeline.execute()
        """

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue MSET commands grouped by the hash slot of each key.

        See ClusterPipeline.mset_nonatomic()
        """

    @abstractmethod
    async def reset(self):
        """
        Drop all per-execution state.

        See: ClusterPipeline.reset()
        """

    @abstractmethod
    def multi(self):
        """
        Enter the transactional block.

        See: ClusterPipeline.multi()
        """

    @abstractmethod
    async def watch(self, *names):
        """
        Watch the given keys.

        See: ClusterPipeline.watch()
        """

    @abstractmethod
    async def unwatch(self):
        """
        Stop watching every previously watched key.

        See: ClusterPipeline.unwatch()
        """

    @abstractmethod
    async def discard(self):
        """
        Abandon the transactional block.

        See: ClusterPipeline.discard()
        """

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink the keys given in ``names``.

        See: ClusterPipeline.unlink()
        """

    @abstractmethod
    def __len__(self) -> int:
        """Number of commands currently queued."""
2485
2486
class AbstractStrategy(ExecutionStrategy):
    """
    Shared base for execution strategies.

    Owns the owning pipeline reference and the queue of
    :class:`PipelineCommand` objects. Subclasses implement the actual
    execution semantics.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """
        Initialize the cluster client if it still needs it, clear the
        command queue and return the pipeline.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append the command to the queue, tagged with its queue position, and
        return the pipeline so calls can be chained.
        """
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled.

        Rewrites ``exception.args`` in place so its first element reads
        ``Command # <number> (<truncated command>) of pipeline caused error:
        <original message>``. Remaining args are kept.

        :param exception: the exception to annotate.
        :param number: the command position to report.
        :param command: the command's argument sequence.
        """
        cmd = " ".join(map(safe_str, command))
        # An exception created without arguments has empty ``args``. Fall back
        # to its class name instead of raising IndexError while the caller is
        # handling the original error.
        detail = exception.args[0] if exception.args else type(exception).__name__
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {detail}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        """Number of queued commands."""
        return len(self._command_queue)
2555
2556
class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional execution strategy.

    Queued commands are routed to their target node, and each node receives
    its share as a single pipeline; results are returned in queue order.
    WATCH/MULTI/UNWATCH/DISCARD are rejected.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one MSET per hash slot, each holding the key/value pairs of
        ``mapping`` whose keys hash to that slot.
        """
        encoder = self._pipe.cluster_client.encoder

        slots_pairs = {}
        for pair in mapping.items():
            slot = key_slot(encoder.encode(pair[0]))
            slots_pairs.setdefault(slot, []).extend(pair)

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands, retrying on ``RedisCluster.ERRORS_ALLOW_RETRY``.

        Between attempts the cluster client is closed and the strategy sleeps
        0.25s. The number of retries comes from ``cluster_client.retry``. The
        queue is always cleared afterwards.
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY:
                    if retry_attempts <= 0:
                        # Out of retries: propagate the original error.
                        raise
                    # Try again with the new cluster setup. All other errors
                    # should be raised.
                    retry_attempts -= 1
                    await self._pipe.cluster_client.aclose()
                    await asyncio.sleep(0.25)
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        Send the pending commands of ``stack`` as one pipeline per target node.

        Commands that already hold a non-exception result from a previous
        attempt are skipped. Each remaining command must resolve to exactly
        one node. If any node reported errors:

        * redirected commands (TryAgain/MOVED/ASK) are retried one by one
          through ``client.execute_command`` when ``allow_redirections`` is
          set;
        * the first remaining error is annotated and raised if
          ``raise_on_error`` is set;
        * the client's default node is replaced if it failed with a retryable
          error.

        NOTE(review): ``target_nodes`` is popped from each command's kwargs,
        so a later retry attempt routes that command by key instead.

        :returns: every command's result (reply or exception), in stack order.
        :raises RedisClusterException: if a command maps to no node or to
            more than one node.
        """
        # Only re-send commands that have not produced a reply yet, or that
        # failed. Compare against None rather than truthiness, so successful
        # falsy replies (0, False, "", []) are not sent a second time on retry.
        todo = [
            cmd
            for cmd in stack
            if cmd.result is None or isinstance(cmd.result, Exception)
        ]

        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Start timing for observability
        start_time = time.monotonic()

        errors = await asyncio.gather(
            *(
                asyncio.create_task(target.execute_pipeline(commands))
                for target, commands in nodes.values()
            )
        )

        # Record operation duration for each node
        for node, commands in nodes.values():
            # Find the first error in this node's commands, if any
            node_error = None
            for cmd in commands:
                if isinstance(cmd.result, Exception):
                    node_error = cmd.result
                    break

            db = node.connection_kwargs.get("db", 0)
            await record_operation_duration(
                command_name="PIPELINE",
                duration_seconds=time.monotonic() - start_time,
                server_address=node.host,
                server_port=node.port,
                db_namespace=str(db) if db is not None else None,
                error=node_error,
            )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        # Report the error message itself (args[0]), not the
                        # repr of the whole args tuple. Exceptions created
                        # without arguments fall back to their class name.
                        detail = (
                            result.args[0] if result.args else type(result).__name__
                        )
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {detail}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue UNLINK for exactly one key; more keys are rejected."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2782
2783
2784class TransactionStrategy(AbstractStrategy):
    # Commands that carry no key: no hash slot is computed for them in
    # _execute_command.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Outside an explicit MULTI block these are sent to the server right away
    # (via _immediate_execute_command) instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # NOTE(review): not referenced in this part of the class; presumably the
    # commands after which WATCH state ends. Confirm against the rest of the class.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    # Slot redirections; __init__ adds them to this strategy's retryable errors.
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    # NOTE(review): not referenced in this part of the class; presumably the
    # errors treated as connection/topology failures during a transaction. Confirm.
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )
2795
2796 def __init__(self, pipe: ClusterPipeline) -> None:
2797 super().__init__(pipe)
2798 self._explicit_transaction = False
2799 self._watching = False
2800 self._pipeline_slots: Set[int] = set()
2801 self._transaction_node: Optional[ClusterNode] = None
2802 self._transaction_connection: Optional[Connection] = None
2803 self._executing = False
2804 self._retry = copy(self._pipe.cluster_client.retry)
2805 self._retry.update_supported_errors(
2806 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2807 )
2808
2809 def _get_client_and_connection_for_transaction(
2810 self,
2811 ) -> Tuple[ClusterNode, Connection]:
2812 """
2813 Find a connection for a pipeline transaction.
2814
2815 For running an atomic transaction, watch keys ensure that contents have not been
2816 altered as long as the watch commands for those keys were sent over the same
2817 connection. So once we start watching a key, we fetch a connection to the
2818 node that owns that slot and reuse it.
2819 """
2820 if not self._pipeline_slots:
2821 raise RedisClusterException(
2822 "At least a command with a key is needed to identify a node"
2823 )
2824
2825 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2826 list(self._pipeline_slots)[0], False
2827 )
2828 self._transaction_node = node
2829
2830 if not self._transaction_connection:
2831 connection: Connection = self._transaction_node.acquire_connection()
2832 self._transaction_connection = connection
2833
2834 return self._transaction_node, self._transaction_connection
2835
2836 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2837 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2838 response = None
2839 error = None
2840
2841 def runner():
2842 nonlocal response
2843 nonlocal error
2844 try:
2845 response = asyncio.run(self._execute_command(*args, **kwargs))
2846 except Exception as e:
2847 error = e
2848
2849 thread = threading.Thread(target=runner)
2850 thread.start()
2851 thread.join()
2852
2853 if error:
2854 raise error
2855
2856 return response
2857
    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Route one command either to immediate execution or to the parent strategy.

        Called via ``asyncio.run`` from ``execute_command``. Unless the command
        is in ``NO_SLOTS_COMMANDS``, its slot is determined and recorded in
        ``_pipeline_slots``. While watching (or for
        ``IMMEDIATE_EXECUTE_COMMANDS``) and outside an explicit MULTI, the
        command is prepared for immediate execution. Otherwise it is handed to
        the parent strategy's ``execute_command``, which presumably queues it
        (the parent is not visible here; verify).

        :raises CrossSlotTransactionError: if an immediate command targets a
            slot other than the one already recorded.
        :raises RedisClusterException: if no slot can be determined for an
            immediate command that is not in ``NO_SLOTS_COMMANDS``.
        """
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                # Everything sent over the shared transaction connection must
                # stay on a single slot.
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]},"
                    "it cannot be triggered in a transaction"
                )

            # Intentionally NOT awaited. The coroutine object is returned
            # through asyncio.run() in execute_command() and awaited by the
            # caller (e.g. ``await self.execute_command("WATCH", ...)`` in
            # watch()). It therefore runs on the caller's event loop, not on
            # the temporary loop asyncio.run() closes when it returns.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            # Queued commands are only checked for a single slot at EXEC time
            # (see _execute_transaction).
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)
2893
2894 def _validate_watch(self):
2895 if self._explicit_transaction:
2896 raise RedisError("Cannot issue a WATCH after a MULTI")
2897
2898 self._watching = True
2899
2900 async def _immediate_execute_command(self, *args, **options):
2901 return await self._retry.call_with_retry(
2902 lambda: self._get_connection_and_send_command(*args, **options),
2903 self._reinitialize_on_error,
2904 with_failure_count=True,
2905 )
2906
2907 async def _get_connection_and_send_command(self, *args, **options):
2908 redis_node, connection = self._get_client_and_connection_for_transaction()
2909 # Only disconnect if not watching - disconnecting would lose WATCH state
2910 if not self._watching:
2911 await redis_node.disconnect_if_needed(connection)
2912
2913 # Start timing for observability
2914 start_time = time.monotonic()
2915
2916 try:
2917 response = await self._send_command_parse_response(
2918 connection, redis_node, args[0], *args, **options
2919 )
2920
2921 await record_operation_duration(
2922 command_name=args[0],
2923 duration_seconds=time.monotonic() - start_time,
2924 server_address=connection.host,
2925 server_port=connection.port,
2926 db_namespace=str(connection.db),
2927 )
2928
2929 return response
2930 except Exception as e:
2931 e.connection = connection
2932 await record_operation_duration(
2933 command_name=args[0],
2934 duration_seconds=time.monotonic() - start_time,
2935 server_address=connection.host,
2936 server_port=connection.port,
2937 db_namespace=str(connection.db),
2938 error=e,
2939 )
2940 raise
2941
2942 async def _send_command_parse_response(
2943 self,
2944 connection: Connection,
2945 redis_node: ClusterNode,
2946 command_name,
2947 *args,
2948 **options,
2949 ):
2950 """
2951 Send a command and parse the response
2952 """
2953
2954 await connection.send_command(*args)
2955 output = await redis_node.parse_response(connection, command_name, **options)
2956
2957 if command_name in self.UNWATCH_COMMANDS:
2958 self._watching = False
2959 return output
2960
2961 async def _reinitialize_on_error(self, error, failure_count):
2962 if hasattr(error, "connection"):
2963 await record_error_count(
2964 server_address=error.connection.host,
2965 server_port=error.connection.port,
2966 network_peer_address=error.connection.host,
2967 network_peer_port=error.connection.port,
2968 error_type=error,
2969 retry_attempts=failure_count,
2970 is_internal=True,
2971 )
2972
2973 if self._watching:
2974 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2975 raise WatchError("Slot rebalancing occurred while watching keys")
2976
2977 if (
2978 type(error) in self.SLOT_REDIRECT_ERRORS
2979 or type(error) in self.CONNECTION_ERRORS
2980 ):
2981 if self._transaction_connection and self._transaction_node:
2982 # Disconnect and release back to pool
2983 await self._transaction_connection.disconnect()
2984 self._transaction_node.release(self._transaction_connection)
2985 self._transaction_connection = None
2986
2987 self._pipe.cluster_client.reinitialize_counter += 1
2988 if (
2989 self._pipe.cluster_client.reinitialize_steps
2990 and self._pipe.cluster_client.reinitialize_counter
2991 % self._pipe.cluster_client.reinitialize_steps
2992 == 0
2993 ):
2994 await self._pipe.cluster_client.nodes_manager.initialize()
2995 self.reinitialize_counter = 0
2996 else:
2997 if isinstance(error, AskError):
2998 await self._pipe.cluster_client.nodes_manager.move_slot(error)
2999
3000 self._executing = False
3001
3002 async def _raise_first_error(self, responses, stack, start_time):
3003 """
3004 Raise the first exception on the stack
3005 """
3006 for r, cmd in zip(responses, stack):
3007 if isinstance(r, Exception):
3008 self._annotate_exception(r, cmd.position + 1, cmd.args)
3009
3010 await record_operation_duration(
3011 command_name="TRANSACTION",
3012 duration_seconds=time.monotonic() - start_time,
3013 server_address=self._transaction_connection.host,
3014 server_port=self._transaction_connection.port,
3015 db_namespace=str(self._transaction_connection.db),
3016 error=r,
3017 )
3018
3019 raise r
3020
3021 def mset_nonatomic(
3022 self, mapping: Mapping[AnyKeyT, EncodableT]
3023 ) -> "ClusterPipeline":
3024 raise NotImplementedError("Method is not supported in transactional context.")
3025
3026 async def execute(
3027 self, raise_on_error: bool = True, allow_redirections: bool = True
3028 ) -> List[Any]:
3029 stack = self._command_queue
3030 if not stack and (not self._watching or not self._pipeline_slots):
3031 return []
3032
3033 return await self._execute_transaction_with_retries(stack, raise_on_error)
3034
3035 async def _execute_transaction_with_retries(
3036 self, stack: List["PipelineCommand"], raise_on_error: bool
3037 ):
3038 return await self._retry.call_with_retry(
3039 lambda: self._execute_transaction(stack, raise_on_error),
3040 lambda error, failure_count: self._reinitialize_on_error(
3041 error, failure_count
3042 ),
3043 with_failure_count=True,
3044 )
3045
    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Make one attempt at sending the queued commands wrapped in MULTI/EXEC.

        Retries are driven by ``_execute_transaction_with_retries``. MULTI,
        every queued command and EXEC are packed into a single write on the
        transaction connection, and the replies are then read back in order.

        :param stack: the queued ``PipelineCommand`` objects. ``execute()``
            passes ``self._command_queue``, which the reply loop below also
            iterates directly.
        :param raise_on_error: raise the first per-command error found in the
            EXEC reply.
        :return: the EXEC results with response callbacks applied.
        :raises CrossSlotTransactionError: if queued commands span several slots.
        :raises WatchError: if the parsed EXEC reply is ``None`` (treated as a
            watched key having changed).
        :raises InvalidPipelineStack: if the EXEC reply length does not match
            the number of queued commands.
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        # While this flag and _watching are set, _reinitialize_on_error turns a
        # slot redirect into WatchError instead of retrying.
        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands carrying EMPTY_RESPONSE are not sent at all; their canned
        # value is spliced back into the results after EXEC.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)
        # Holds (index, value) tuples for EMPTY_RESPONSE commands and
        # exception objects for commands whose queueing reply was an error.
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            # Lets _reinitialize_on_error record which connection failed.
            cluster_error.connection = connection
            raise

        # and all the other commands
        # Redirect and response errors are collected instead of raised so the
        # remaining replies (including EXEC's) are still drained; connection
        # errors abort the attempt.
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # Prefer the error of the command that failed to queue, if any.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): this unpacks every entry of `errors` as (i, e). Bare
        # exception objects appended above are expected to have been raised
        # via ExecAbortError before reaching here; if one ever is not, this
        # loop fails with a TypeError. Confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        # NOTE(review): with raise_on_error=False this still raises the first
        # in-EXEC error whenever `errors` is non-empty (e.g. it only holds
        # EMPTY_RESPONSE placeholders). Confirm this is intended.
        if raise_on_error or len(errors) > 0:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data
3164
3165 async def reset(self):
3166 self._command_queue = []
3167
3168 # make sure to reset the connection state in the event that we were
3169 # watching something
3170 if self._transaction_connection:
3171 try:
3172 if self._watching:
3173 # call this manually since our unwatch or
3174 # immediate_execute_command methods can call reset()
3175 await self._transaction_connection.send_command("UNWATCH")
3176 await self._transaction_connection.read_response()
3177 # we can safely return the connection to the pool here since we're
3178 # sure we're no longer WATCHing anything
3179 await self._transaction_node.disconnect_if_needed(
3180 self._transaction_connection
3181 )
3182 self._transaction_node.release(self._transaction_connection)
3183 self._transaction_connection = None
3184 except self.CONNECTION_ERRORS:
3185 # disconnect will also remove any previous WATCHes
3186 if self._transaction_connection and self._transaction_node:
3187 await self._transaction_connection.disconnect()
3188 self._transaction_node.release(self._transaction_connection)
3189 self._transaction_connection = None
3190
3191 # clean up the other instance attributes
3192 self._transaction_node = None
3193 self._watching = False
3194 self._explicit_transaction = False
3195 self._pipeline_slots = set()
3196 self._executing = False
3197
3198 def multi(self):
3199 if self._explicit_transaction:
3200 raise RedisError("Cannot issue nested calls to MULTI")
3201 if self._command_queue:
3202 raise RedisError(
3203 "Commands without an initial WATCH have already been issued"
3204 )
3205 self._explicit_transaction = True
3206
3207 async def watch(self, *names):
3208 if self._explicit_transaction:
3209 raise RedisError("Cannot issue a WATCH after a MULTI")
3210
3211 return await self.execute_command("WATCH", *names)
3212
3213 async def unwatch(self):
3214 if self._watching:
3215 return await self.execute_command("UNWATCH")
3216
3217 return True
3218
3219 async def discard(self):
3220 await self.reset()
3221
3222 async def unlink(self, *names):
3223 return self.execute_command("UNLINK", *names)
3224
3225
class _ClusterNodePoolAdapter(ConnectionPoolInterface):
    """Thin adapter exposing the :class:`ConnectionPoolInterface` that
    :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own
    connection pool.

    Connections are acquired from the node via
    :meth:`ClusterNode.acquire_connection` and returned via
    :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already
    disconnects the connection *before* calling :meth:`release`, so the
    connection is returned to the node's free-queue in a disconnected
    state — guaranteeing that a subscribed socket is never silently
    reused for regular commands.

    Methods that do not apply to this adapter (the underlying node's
    lifecycle is managed by the cluster, not by individual PubSub
    instances) are implemented as no-ops so the adapter remains a valid
    :class:`ConnectionPoolInterface`.
    """

    def __init__(self, node: "ClusterNode") -> None:
        self._node = node
        self.connection_kwargs = node.connection_kwargs

    # -- methods used by PubSub ------------------------------------------------

    def get_encoder(self) -> Encoder:
        return self._node.get_encoder()

    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> AbstractConnection:
        """
        Acquire a connection from the node and make sure it is connected.

        If ``connect()`` fails or is cancelled, the connection is disconnected
        and always handed back to the node before the original exception is
        re-raised.
        """
        connection = self._node.acquire_connection()
        try:
            await connection.connect()
        except BaseException:
            # connect() may fail mid-handshake (e.g. after the TCP socket
            # is established but before AUTH/HELLO completes) leaving the
            # connection in a partially-connected state. Disconnect before
            # returning it to the node's free queue so it is not reused.
            try:
                await connection.disconnect()
            except Exception:
                # Best-effort cleanup: let the original connect() failure
                # propagate rather than a secondary disconnect() error.
                pass
            finally:
                # Release even if disconnect() raised or was cancelled;
                # otherwise the connection would never go back to the node.
                self._node.release(connection)
            raise
        return connection

    async def release(self, connection: AbstractConnection) -> None:
        # PubSub.aclose() disconnects the connection before calling
        # release(), so it is safe to put it back in the node's free
        # queue – it will reconnect lazily on next use.
        await self._node.disconnect_if_needed(connection)
        self._node.release(connection)

    # -- no-op stubs for the rest of ConnectionPoolInterface -------------------
    # The node's connections are shared with regular cluster traffic and its
    # lifecycle is managed by RedisCluster / NodesManager, so the adapter must
    # not reset, disconnect, retry-configure or re-auth them on behalf of a
    # single PubSub instance.

    def get_protocol(self):
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        pass

    async def disconnect(self, inuse_connections: bool = True) -> None:
        pass

    async def aclose(self) -> None:
        pass

    def set_retry(self, retry: "Retry") -> None:
        pass

    async def re_auth_callback(self, token: TokenInterface) -> None:
        pass

    def get_connection_count(self) -> List[Tuple[int, dict]]:
        return []
3303
3304
def _unregister_slots_cache_listener(
    dispatcher_ref: "weakref.ref[EventDispatcher]",
    listener: AsyncEventListenerInterface,
    event_type: Type[object],
) -> None:
    """
    Finalizer callback that detaches ``listener`` from its dispatcher.

    It lives at module level so that registering it with ``weakref.finalize``
    captures no strong reference to the owning ClusterPubSub. It does
    nothing once the dispatcher itself has been garbage-collected.
    """
    target = dispatcher_ref()
    if target is None:
        return
    target.unregister_listeners({event_type: [listener]})
3316
3317
class ClusterPubSubSlotsCacheListener(AsyncEventListenerInterface):
    """
    Async listener that relays slots-cache refresh events
    (AsyncAfterSlotsCacheRefreshEvent) to a ClusterPubSub.

    Only a weak reference to the pubsub is kept, so the listener never keeps
    it alive. The dispatcher's strong reference to this listener is removed
    by a ``weakref.finalize`` that ``ClusterPubSub.__init__`` attaches to the
    owning pubsub.
    """

    def __init__(self, pubsub: "ClusterPubSub") -> None:
        self._pubsub_ref: "weakref.ref[ClusterPubSub]" = weakref.ref(pubsub)

    async def listen(self, event: object) -> None:
        """Call the pubsub's ``on_slots_changed`` hook if it is still alive."""
        target = self._pubsub_ref()
        if target is not None:
            try:
                await target.on_slots_changed()
            except Exception as exc:
                # A failing pubsub must not break the slots-cache refresh for
                # the other listeners; log it and carry on.
                logger.exception(
                    "pubsub %r raised during slots-cache change: %s: %s",
                    target,
                    type(exc).__name__,
                    exc,
                )
        # If target is None, the pubsub was collected and its finalizer is
        # about to unregister this listener, so there is nothing to do.
3349
3350
3351class ClusterPubSub(PubSub):
3352 """
3353 Async cluster implementation for pub/sub.
3354
3355 IMPORTANT: before using ClusterPubSub, read about the known limitations
3356 with pubsub in Cluster mode and learn how to workaround them:
3357 https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations
3358 """
3359
    def __init__(
        self,
        redis_cluster: "RedisCluster",
        node: Optional["ClusterNode"] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        **kwargs: Any,
    ) -> None:
        """
        When a pubsub instance is created without specifying a node, a single
        node will be transparently chosen for the pubsub connection on the
        first command execution. The node will be determined by:
         1. Hashing the channel name in the request to find its keyslot
         2. Selecting a node that handles the keyslot: If read_from_replicas is
            set to true or load_balancing_strategy is set, a replica can be selected.

        The instance also registers itself (via a weakly-referencing listener)
        for AsyncAfterSlotsCacheRefreshEvent on the cluster's nodes manager
        dispatcher. The registration is removed when the instance is
        garbage-collected.

        :param redis_cluster: RedisCluster instance
        :param node: ClusterNode to connect to
        :param host: Host of the node to connect to
        :param port: Port of the node to connect to
        :param push_handler_func: Optional push handler function
        :param event_dispatcher: Optional event dispatcher
        :param kwargs: Additional keyword arguments
        :raises DataError: if only one of ``host`` / ``port`` is given.
        """
        # Initialise first: set_pubsub_node() may raise before assigning it.
        self.node = None
        self.set_pubsub_node(redis_cluster, node, host, port)

        # Borrow the node's own connection pool via an adapter rather than
        # creating a second, detached ConnectionPool for pubsub.
        if self.node is not None:
            connection_pool = _ClusterNodePoolAdapter(self.node)
        else:
            connection_pool = None

        self.cluster = redis_cluster
        # node.name -> per-node PubSub used for shard (SSUBSCRIBE) channels.
        self.node_pubsub_mapping: Dict[str, PubSub] = {}
        # Reverse index: shard channel (normalized) -> owning node.name. Used to
        # route sunsubscribe calls and reconcile subscriptions after slot
        # migration / failover.
        self._shard_channel_to_node: Dict[Any, str] = {}
        # Dedicated lock for shard-subscription bookkeeping. Distinct from
        # PubSub.self._lock (which serializes wire I/O on the cluster-level
        # connection used by aclose / send_command / regular subscribe) so
        # that reconciliation cannot starve those unrelated coroutines
        # during long per-channel migrations.
        self._shard_state_lock: asyncio.Lock = asyncio.Lock()
        # Background tasks created by on_slots_changed; kept to prevent GC.
        self._reconcile_tasks: Set[asyncio.Task] = set()
        # NOTE: this rebinds the name on the instance, so from here on
        # ``self._pubsubs_generator`` is the round-robin generator object that
        # _sharded_message_generator() calls next() on. The method of the same
        # name is shadowed for this instance.
        self._pubsubs_generator = self._pubsubs_generator()
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        super().__init__(
            connection_pool=connection_pool,
            encoder=redis_cluster.encoder,
            push_handler_func=push_handler_func,
            event_dispatcher=self._event_dispatcher,
            **kwargs,
        )
        # Subscribe to slots-cache change notifications so shard subscriptions
        # can be reconciled automatically after topology refreshes.
        nm_dispatcher = redis_cluster.nodes_manager._event_dispatcher
        self._slots_cache_listener = ClusterPubSubSlotsCacheListener(self)
        nm_dispatcher.register_listeners(
            {AsyncAfterSlotsCacheRefreshEvent: [self._slots_cache_listener]}
        )
        # Deterministic GC-time cleanup so short-lived pubsubs do not leak
        # listeners in the dispatcher when no slots-refresh event ever fires.
        # The callback gets only a weakref to the dispatcher and the listener
        # (not ``self``), so the finalizer does not keep this instance alive.
        weakref.finalize(
            self,
            _unregister_slots_cache_listener,
            weakref.ref(nm_dispatcher),
            self._slots_cache_listener,
            AsyncAfterSlotsCacheRefreshEvent,
        )
3438
3439 def set_pubsub_node(
3440 self,
3441 cluster: "RedisCluster",
3442 node: Optional["ClusterNode"] = None,
3443 host: Optional[str] = None,
3444 port: Optional[int] = None,
3445 ) -> None:
3446 """
3447 The pubsub node will be set according to the passed node, host and port
3448 When none of the node, host, or port are specified - the node is set
3449 to None and will be determined by the keyslot of the channel in the
3450 first command to be executed.
3451 RedisClusterException will be thrown if the passed node does not exist
3452 in the cluster.
3453 If host is passed without port, or vice versa, a DataError will be
3454 thrown.
3455 """
3456 if node is not None:
3457 # node is passed by the user
3458 self._raise_on_invalid_node(cluster, node, node.host, node.port)
3459 pubsub_node = node
3460 elif host is not None and port is not None:
3461 # host and port passed by the user
3462 node = cluster.get_node(host=host, port=port)
3463 self._raise_on_invalid_node(cluster, node, host, port)
3464 pubsub_node = node
3465 elif host is not None or port is not None:
3466 # only one of host and port is specified
3467 raise DataError("Specify both host and port")
3468 else:
3469 # nothing specified by the user
3470 pubsub_node = None
3471 self.node = pubsub_node
3472
3473 def get_pubsub_node(self) -> Optional["ClusterNode"]:
3474 """
3475 Get the node that is being used as the pubsub connection.
3476
3477 :return: The ClusterNode being used for pubsub, or None if not yet determined
3478 """
3479 return self.node
3480
3481 async def _resubscribe_shard_channels(self) -> None:
3482 # A single node can own multiple slot ranges, so a batched
3483 # ``SSUBSCRIBE`` covering every tracked channel would be rejected by
3484 # Redis with a ``CROSSSLOT`` error. Group by hash slot and emit one
3485 # ``SSUBSCRIBE`` per slot.
3486 by_slot: defaultdict[int, dict] = defaultdict(dict)
3487 for k, v in self.shard_channels.items():
3488 by_slot[key_slot(self.encoder.encode(k))][k] = v
3489 for subscriptions in by_slot.values():
3490 await self._resubscribe(subscriptions, self.ssubscribe)
3491
3492 def _get_node_pubsub(self, node: "ClusterNode") -> PubSub:
3493 """Get or create a PubSub instance for the given node."""
3494 try:
3495 return self.node_pubsub_mapping[node.name]
3496 except KeyError:
3497 pubsub = PubSub(
3498 connection_pool=_ClusterNodePoolAdapter(node),
3499 encoder=self.cluster.encoder,
3500 push_handler_func=self.push_handler_func,
3501 event_dispatcher=self._event_dispatcher,
3502 )
3503 # Replay shard subscriptions on reconnect with slot-aware grouping
3504 # so that channels spanning multiple slots owned by this node do
3505 # not trigger a CROSSSLOT error.
3506 pubsub._resubscribe_shard_channels = MethodType(
3507 ClusterPubSub._resubscribe_shard_channels, pubsub
3508 )
3509 self.node_pubsub_mapping[node.name] = pubsub
3510 return pubsub
3511
3512 def _find_node_name_for_pubsub(self, pubsub: PubSub) -> Optional[str]:
3513 for name, candidate in self.node_pubsub_mapping.items():
3514 if candidate is pubsub:
3515 return name
3516 return None
3517
3518 async def _sharded_message_generator(
3519 self, timeout: float = 0.0
3520 ) -> Tuple[Optional[PubSub], Optional[Dict[str, Any]]]:
3521 """Generate messages from shard channels across all nodes."""
3522 for _ in range(len(self.node_pubsub_mapping)):
3523 pubsub = next(self._pubsubs_generator)
3524 # Don't pass ignore_subscribe_messages here - let get_sharded_message
3525 # handle the filtering after processing subscription state changes
3526 message = await pubsub.get_message(
3527 ignore_subscribe_messages=False, timeout=timeout
3528 )
3529 if message is not None:
3530 return pubsub, message
3531 return None, None
3532
3533 def _pubsubs_generator(self) -> Generator[PubSub, None, None]:
3534 """Generator that yields PubSub instances in round-robin fashion."""
3535 while True:
3536 current_nodes = list(self.node_pubsub_mapping.values())
3537 if not current_nodes:
3538 return # Avoid infinite loop when no subscriptions exist
3539 yield from current_nodes
3540
    async def get_sharded_message(
        self,
        ignore_subscribe_messages: bool = False,
        timeout: float = 0.0,
        target_node: Optional["ClusterNode"] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Get a message from shard channels.

        With ``target_node`` only that node's per-node pubsub is polled (and
        ``None`` is returned if there is none). Otherwise the per-node pubsubs
        are polled round-robin. A ``sunsubscribe`` confirmation also updates
        cluster-level tracking and may close the per-node pubsub that
        delivered it.

        :param ignore_subscribe_messages: Whether to ignore subscribe messages
            (``ssubscribe`` / ``sunsubscribe``); also suppressed when the
            instance-level ``ignore_subscribe_messages`` is set.
        :param timeout: Timeout for message retrieval
        :param target_node: Specific node to get message from
        :return: Message dictionary or None
        """
        pubsub: Optional[PubSub]
        if target_node:
            pubsub = self.node_pubsub_mapping.get(target_node.name)
            if pubsub:
                # Don't pass ignore_subscribe_messages here - let get_sharded_message
                # handle the filtering after processing subscription state changes
                message = await pubsub.get_message(
                    ignore_subscribe_messages=False, timeout=timeout
                )
            else:
                message = None
        else:
            pubsub, message = await self._sharded_message_generator(timeout=timeout)

        if message is None:
            return None
        # Only sunsubscribe mutates cluster-level shard state; bypassing the
        # lock on the data-message hot path keeps smessage delivery from
        # competing with the reconciliation task for _shard_state_lock.
        if str_if_bytes(message["type"]) == "sunsubscribe":
            # Serialize state mutation against reinitialize_shard_subscriptions
            # (background task). The blocking get_message above intentionally
            # runs outside the lock so reconciliation is not stalled by long
            # polls.
            async with self._shard_state_lock:
                if message["channel"] in self.pending_unsubscribe_shard_channels:
                    # User-initiated sunsubscribe: drop from cluster-level tracking.
                    self.pending_unsubscribe_shard_channels.remove(message["channel"])
                    self.shard_channels.pop(message["channel"], None)
                    self._shard_channel_to_node.pop(message["channel"], None)
                # Drop the per-node pubsub that delivered the confirmation once
                # it no longer holds any shard subscriptions, regardless of
                # whether the sunsubscribe was user-initiated or driven by
                # slot-migration reconciliation (_migrate_shard_channel, which
                # intentionally does not add the channel to
                # pending_unsubscribe_shard_channels). This releases the
                # dedicated connection that would otherwise linger.
                # Identifying the receiving pubsub directly (rather than via
                # the cluster's current slot map) is required after slot
                # migration, where the channel's owner is no longer the node
                # that received our original SSUBSCRIBE.
                if pubsub is not None and not pubsub.subscribed:
                    name = self._find_node_name_for_pubsub(pubsub)
                    if name is not None:
                        try:
                            await pubsub.aclose()
                        except Exception:
                            # Best effort: the pubsub is dropped from the
                            # mapping below whether or not close succeeds.
                            pass
                        self.node_pubsub_mapping.pop(name, None)

        # Only suppress subscribe/unsubscribe messages, not data messages (smessage)
        if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"):
            if self.ignore_subscribe_messages or ignore_subscribe_messages:
                return None
        return message
3610
    async def ssubscribe(
        self, *args: ChannelT | Subscription, **kwargs: PubSubHandler
    ) -> None:
        """
        Subscribe to shard channels.

        Each channel is subscribed on the per-node pubsub of the node that
        currently owns its slot (channels with no resolvable node are skipped).
        A channel already tracked against a different node is migrated to the
        current owner instead. The reverse index and ``shard_channels`` are
        updated, and the channel is removed from
        ``pending_unsubscribe_shard_channels``.

        :param args: Channel names or ``Subscription`` objects
        :param kwargs: Channel names with handlers
        """
        s_channels = parse_pubsub_subscriptions(args, kwargs)

        # Serialize against reinitialize_shard_subscriptions (background
        # task) so the reverse index, shard_channels, and node_pubsub_mapping
        # are not mutated concurrently. _migrate_shard_channel below does not
        # re-acquire this lock (asyncio.Lock is non-reentrant).
        async with self._shard_state_lock:
            for s_channel, handler in s_channels.items():
                node = self.cluster.get_node_from_key(s_channel)
                if not node:
                    continue
                # Lazy re-route: if this channel is already tracked against a
                # different node (e.g. after a slot migration), migrate it now
                # so the caller's intent is applied on the current owner.
                normalized_key = next(iter(self._normalize_keys({s_channel: None})))
                old_name = self._shard_channel_to_node.get(normalized_key)
                if old_name and old_name != node.name:
                    # Match PubSub.ssubscribe() dict.update() semantics: the
                    # caller's newly supplied handler (including None) always
                    # overrides any previously registered handler.
                    await self._migrate_shard_channel(
                        normalized_key,
                        handler,
                        old_name,
                        node,
                    )
                    continue
                pubsub = self._get_node_pubsub(node)
                if handler:
                    await pubsub.ssubscribe(Subscription(s_channel, handler))
                else:
                    await pubsub.ssubscribe(s_channel)
                # Mirror the per-node state into cluster-level tracking.
                self.shard_channels.update(pubsub.shard_channels)
                self._shard_channel_to_node[normalized_key] = node.name
                # A fresh subscribe cancels any pending user sunsubscribe.
                self.pending_unsubscribe_shard_channels.difference_update(
                    self._normalize_keys({s_channel: None})
                )
3657
    async def sunsubscribe(self, *args: Any) -> None:
        """
        Unsubscribe from shard channels.

        This only sends SUNSUBSCRIBE on the per-node pubsub holding each
        channel and records it as pending. Cluster-level tracking is cleared
        later, when the confirmation is read in ``get_sharded_message``.
        Channels with no per-node pubsub are skipped.

        :param args: Channel names to unsubscribe from. If empty, unsubscribe from all.
        """
        if args:
            args = list_or_args(args[0], args[1:])
        else:
            args = list(self.shard_channels.keys())

        # Serialize against reinitialize_shard_subscriptions: the reverse
        # index and node_pubsub_mapping must not change between the lookup
        # and the per-node sunsubscribe call below.
        async with self._shard_state_lock:
            for s_channel in args:
                normalized_key = next(iter(self._normalize_keys({s_channel: None})))
                # Route via the reverse index so we unsubscribe on the node
                # that actually holds the subscription. After a slot migration
                # the cluster's current owner may no longer be that node.
                name = self._shard_channel_to_node.get(normalized_key)
                if name and name in self.node_pubsub_mapping:
                    pubsub = self.node_pubsub_mapping[name]
                else:
                    # Fall back to the cluster's current slot owner.
                    node = self.cluster.get_node_from_key(s_channel)
                    if not node or node.name not in self.node_pubsub_mapping:
                        continue
                    pubsub = self.node_pubsub_mapping[node.name]
                await pubsub.sunsubscribe(s_channel)
                self.pending_unsubscribe_shard_channels.update(
                    pubsub.pending_unsubscribe_shard_channels
                )
3690
    async def reinitialize_shard_subscriptions(self) -> None:
        """
        Reconcile per-node shard subscriptions against the cluster's current
        slot ownership map. For each tracked shard channel whose owning node
        has changed (e.g. after CLUSTER SETSLOT / failover), sunsubscribe on
        the old node's pubsub and ssubscribe on the new owner's pubsub,
        preserving any registered handler.

        Per-node pubsubs left with no subscriptions are closed and dropped.

        :raises SlotNotCoveredError: if any tracked channel's slot is not
            currently covered. All coverable channels are still reconciled
            before this is raised.
        :raises ConnectionError | TimeoutError | OSError: the first transient
            migration error, but only when no channel was migrated
            successfully in this pass.
        """
        uncovered: list = []
        made_progress = False
        first_migrate_error: Optional[BaseException] = None
        async with self._shard_state_lock:
            # Iterate a copy: migration may mutate shard_channels.
            for channel, handler in list(self.shard_channels.items()):
                try:
                    new_node = self.cluster.get_node_from_key(channel)
                except SlotNotCoveredError:
                    # Slot is transiently uncovered (mid-migration / partial
                    # topology refresh). Defer this channel so coverable
                    # siblings still reconcile this pass; we surface the
                    # error below so the caller (and logs) know not every
                    # channel was reconciled. Retry happens on the next
                    # slots-cache change notification.
                    uncovered.append(channel)
                    continue
                old_name = self._shard_channel_to_node.get(channel)
                if old_name == new_node.name:
                    continue
                try:
                    await self._migrate_shard_channel(
                        channel, handler, old_name, new_node
                    )
                    made_progress = True
                except (ConnectionError, TimeoutError, OSError) as e:
                    # Transient connectivity error while subscribing on the
                    # new owner (or unsubscribing on the old owner if its
                    # handler chose to re-raise). Do not abort reconciliation
                    # for sibling channels: _shard_channel_to_node was not
                    # advanced for this channel, so the next slots-cache
                    # change notification will retry it.
                    logger.warning(
                        "shard channel %r migration deferred: %s: %s",
                        channel,
                        type(e).__name__,
                        e,
                    )
                    if first_migrate_error is None:
                        first_migrate_error = e
                    continue
            # Garbage-collect per-node pubsubs that no longer hold any
            # subscription so their connections are released.
            for name, pubsub in list(self.node_pubsub_mapping.items()):
                if not pubsub.subscribed:
                    try:
                        await pubsub.aclose()
                    except Exception:
                        # Best effort: the pubsub is dropped from the mapping
                        # below whether or not close succeeds.
                        pass
                    self.node_pubsub_mapping.pop(name, None)
        if uncovered:
            # Surface the uncovered channels so the caller (and observer
            # notification path) knows reconciliation was incomplete. All
            # coverable siblings have already been migrated above.
            raise SlotNotCoveredError(
                f"{len(uncovered)} shard channel(s) left unreconciled; "
                f"slot(s) not covered by the cluster: {uncovered!r}"
            )
        if first_migrate_error is not None and not made_progress:
            # Every migration attempted in this pass failed transiently and
            # nothing else made progress. Re-raise the first caught error
            # (typically the root cause; later failures are often downstream
            # symptoms of the same unreachable node) so the task's done-
            # callback surfaces a single representative failure through the
            # same logger channel used for SlotNotCoveredError. Per-channel
            # WARNINGs above preserve the full forensic detail.
            raise first_migrate_error
3765
3766 async def _migrate_shard_channel(
3767 self,
3768 channel: Any,
3769 handler: Optional[Callable],
3770 old_name: Optional[str],
3771 new_node: "ClusterNode",
3772 ) -> None:
3773 # Detach from the old per-node pubsub, best-effort: the old node may
3774 # already be unreachable during migration / failover.
3775 if old_name and old_name in self.node_pubsub_mapping:
3776 old_pubsub = self.node_pubsub_mapping[old_name]
3777 try:
3778 await old_pubsub.sunsubscribe(channel)
3779 except (ConnectionError, TimeoutError, OSError):
3780 # redis-py's Connection has already called ``disconnect()``
3781 # before raising (see Connection.read_response /
3782 # send_packed_command with ``disconnect_on_error=True``),
3783 # so ``old_pubsub``'s dedicated socket is gone. Two cases:
3784 #
3785 # 1. The old node is no longer in the cluster topology
3786 # (e.g. removed by failover / topology refresh): no
3787 # reconnect target exists, so ``old_pubsub.subscribed``
3788 # would stay True forever and the end-of-pass GC block
3789 # would skip it. Drop it eagerly so the round-robin
3790 # generator does not keep yielding a dead pubsub that
3791 # produces periodic errors from ``get_sharded_message``.
3792 # 2. The old node is still known (transiently slow /
3793 # unreachable): ``PubSub._execute`` auto-reconnects and
3794 # ``on_connect`` re-subscribes to remaining channels,
3795 # so other subscriptions on the same pubsub recover
3796 # naturally. Leave it alone.
3797 if self.cluster.get_node(node_name=old_name) is None:
3798 try:
3799 await old_pubsub.aclose()
3800 except Exception:
3801 pass
3802 self.node_pubsub_mapping.pop(old_name, None)
3803 # Attach to the new per-node pubsub, preserving the handler. Decode to
3804 # a text key only when we must pass it as a kwarg (handler present).
3805 new_pubsub = self._get_node_pubsub(new_node)
3806 if handler:
3807 await new_pubsub.ssubscribe(Subscription(channel, handler))
3808 else:
3809 await new_pubsub.ssubscribe(channel)
3810 self.shard_channels.update(new_pubsub.shard_channels)
3811 normalized_key = next(iter(self._normalize_keys({channel: None})))
3812 self._shard_channel_to_node[normalized_key] = new_node.name
3813 self.pending_unsubscribe_shard_channels.difference_update(
3814 self._normalize_keys({channel: None})
3815 )
3816
3817 async def on_slots_changed(self) -> None:
3818 # Observer hook invoked by NodesManager after a slots-cache refresh.
3819 # Schedule reconciliation as a separate task so the caller's code
3820 # path (typically MovedError handling in _execute_command) is not
3821 # blocked on the network I/O performed by reinitialize_shard_
3822 # subscriptions. No-op when there are no shard subscriptions to
3823 # reconcile.
3824 if not self.shard_channels:
3825 return
3826 task = asyncio.create_task(self.reinitialize_shard_subscriptions())
3827 self._reconcile_tasks.add(task)
3828 task.add_done_callback(self._reconcile_tasks.discard)
3829 # Consume the task's exception (if any) so Python does not emit a
3830 # "Task exception was never retrieved" warning. reinitialize_shard_
3831 # subscriptions surfaces SlotNotCoveredError when a slot is still
3832 # transiently uncovered; route it through the same logger channel
3833 # as sync ClusterPubSubSlotsCacheListener for consistent observability.
3834 task.add_done_callback(self._log_reconcile_task_exception)
3835
3836 @staticmethod
3837 def _log_reconcile_task_exception(task: "asyncio.Task") -> None:
3838 if task.cancelled():
3839 return
3840 exc = task.exception()
3841 if exc is not None:
3842 logger.error(
3843 "shard subscription reconciliation failed: %r", exc, exc_info=exc
3844 )
3845
3846 def get_redis_connection(self) -> Optional["AbstractConnection"]:
3847 """
3848 Get the Redis connection of the pubsub connected node.
3849
3850 Returns the pubsub's dedicated connection (acquired from its own
3851 connection pool), not from the ClusterNode's connection pool.
3852 This avoids the connection pool resource leak that would occur
3853 if we called node.acquire_connection() without releasing.
3854 """
3855 # Return the pubsub's own dedicated connection, which is acquired
3856 # from self.connection_pool when executing pubsub commands.
3857 # This is safe because it's the connection dedicated to this pubsub
3858 # instance, not a shared pool connection from the ClusterNode.
3859 return self.connection
3860
3861 async def aclose(self) -> None:
3862 """
3863 Disconnect the pubsub connection.
3864 """
3865 # Cancel and gather in-flight reconciliation tasks BEFORE acquiring
3866 # _shard_state_lock. The tasks themselves take that lock inside
3867 # reinitialize_shard_subscriptions; since asyncio.Lock is non-
3868 # reentrant, gathering while holding it would deadlock. Awaiting
3869 # each task with suppressed CancelledError also avoids unhandled-
3870 # exception warnings if the task was created but not yet scheduled.
3871 if self._reconcile_tasks:
3872 tasks = list(self._reconcile_tasks)
3873 for task in tasks:
3874 task.cancel()
3875 await asyncio.gather(*tasks, return_exceptions=True)
3876 # Hold _shard_state_lock across the rest of the teardown so it
3877 # observes the same mutual-exclusion discipline as ssubscribe /
3878 # sunsubscribe / get_sharded_message / reinitialize_shard_
3879 # subscriptions, which all mutate shard_channels,
3880 # _shard_channel_to_node, and node_pubsub_mapping under this lock.
3881 # Without it, super().aclose() rebinds shard_channels and
3882 # pending_unsubscribe_shard_channels in parallel with a concurrent
3883 # user-coroutine mutation that resumes during one of the awaits
3884 # below, silently dropping subscription intent.
3885 async with self._shard_state_lock:
3886 self._reconcile_tasks.clear()
3887 # Close all shard pubsub instances first
3888 for pubsub in self.node_pubsub_mapping.values():
3889 await pubsub.aclose()
3890 # Drop the now-dead per-node pubsubs from the mapping so the
3891 # round-robin in _pubsubs_generator / _sharded_message_generator
3892 # cannot yield them between teardown and re-subscription.
3893 self.node_pubsub_mapping.clear()
3894 # _pubsubs_generator captures node_pubsub_mapping.values() into
3895 # a local list inside ``yield from``; clearing the mapping does
3896 # not reach references already held by that captured snapshot,
3897 # so a generator suspended mid-yield-from would still surface
3898 # the now-aclose()'d per-node pubsubs after re-subscription.
3899 # Recreate it to drop the captured list. type(self) bypasses
3900 # the instance-level self-shadow established at __init__
3901 # (self._pubsubs_generator = self._pubsubs_generator()).
3902 self._pubsubs_generator = type(self)._pubsubs_generator( # type: ignore[method-assign]
3903 self
3904 )
3905 # Let parent handle self.connection disconnect under the lock
3906 # (includes disconnect, release to pool, and clearing
3907 # self.connection)
3908 await super().aclose()
3909 # Clear the reverse index so a reused instance doesn't route
3910 # against stale mappings. super().aclose() has already cleared
3911 # shard_channels.
3912 self._shard_channel_to_node.clear()
3913
3914 def _raise_on_invalid_node(
3915 self,
3916 redis_cluster: "RedisCluster",
3917 node: Optional["ClusterNode"],
3918 host: Optional[str],
3919 port: Optional[int],
3920 ) -> None:
3921 """
3922 Raise a RedisClusterException if the node is None or doesn't exist in
3923 the cluster.
3924 """
3925 if node is None or redis_cluster.get_node(node_name=node.name) is None:
3926 raise RedisClusterException(
3927 f"Node {host}:{port} doesn't exist in the cluster"
3928 )
3929
3930 async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
3931 """
3932 Execute a command on the appropriate cluster node.
3933
3934 Taken code from redis-py and tweaked to make it work within a cluster.
3935 """
3936 # NOTE: don't parse the response in this function -- it could pull a
3937 # legitimate message off the stack if the connection is already
3938 # subscribed to one or more channels
3939
3940 # For shard commands, route to appropriate node
3941 command = args[0].upper() if args else ""
3942 if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"):
3943 if len(args) > 1:
3944 channel = args[1]
3945 node = self.cluster.get_node_from_key(channel)
3946 if node:
3947 pubsub = self._get_node_pubsub(node)
3948 return await pubsub.execute_command(*args, **kwargs)
3949
3950 # For other commands, use the set node or lazily discover one
3951 if self.connection is None:
3952 if self.connection_pool is None:
3953 if len(args) > 1:
3954 # Hash the first channel and get one of the nodes holding
3955 # this slot
3956 channel = args[1]
3957 slot = self.cluster.keyslot(channel)
3958 node = self.cluster.nodes_manager.get_node_from_slot(
3959 slot,
3960 self.cluster.read_from_replicas,
3961 self.cluster.load_balancing_strategy,
3962 )
3963 else:
3964 # Get a random node
3965 node = self.cluster.get_random_node()
3966 self.node = node
3967 self.connection_pool = _ClusterNodePoolAdapter(node)
3968
3969 # Now we have a connection_pool, use parent's execute_command
3970 return await super().execute_command(*args, **kwargs)