1import asyncio
2import collections
3import logging
4import random
5import socket
6import threading
7import time
8import warnings
9import weakref
10from abc import ABC, abstractmethod
11from collections import defaultdict
12from copy import copy
13from itertools import chain
14from types import MethodType
15from typing import (
16 TYPE_CHECKING,
17 Any,
18 Callable,
19 Coroutine,
20 Deque,
21 Dict,
22 Generator,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Set,
28 Tuple,
29 Type,
30 TypeVar,
31 Union,
32)
33
34if TYPE_CHECKING:
35 from redis.asyncio.keyspace_notifications import (
36 AsyncClusterKeyspaceNotifications,
37 )
38
39from redis._parsers import AsyncCommandsParser, Encoder
40from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
41from redis._parsers.helpers import (
42 _RedisCallbacks,
43 _RedisCallbacksRESP2,
44 _RedisCallbacksRESP3,
45)
46from redis.asyncio.client import PubSub, ResponseCallbackT
47from redis.asyncio.connection import (
48 AbstractConnection,
49 Connection,
50 ConnectionPoolInterface,
51 SSLConnection,
52 parse_url,
53)
54from redis.asyncio.lock import Lock
55from redis.asyncio.observability.recorder import (
56 record_error_count,
57 record_operation_duration,
58)
59from redis.asyncio.retry import Retry
60from redis.auth.token import TokenInterface
61from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
62from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
63from redis.cluster import (
64 PIPELINE_BLOCKED_COMMANDS,
65 PRIMARY,
66 REPLICA,
67 SLOT_ID,
68 AbstractRedisCluster,
69 LoadBalancer,
70 LoadBalancingStrategy,
71 block_pipeline_command,
72 get_node_name,
73 parse_cluster_slots,
74)
75from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
76from redis.commands.helpers import list_or_args
77from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
78from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
79from redis.credentials import CredentialProvider
80from redis.driver_info import DriverInfo, resolve_driver_info
81from redis.event import (
82 AfterAsyncClusterInstantiationEvent,
83 AsyncAfterSlotsCacheRefreshEvent,
84 AsyncEventListenerInterface,
85 EventDispatcher,
86)
87from redis.exceptions import (
88 AskError,
89 BusyLoadingError,
90 ClusterDownError,
91 ClusterError,
92 ConnectionError,
93 CrossSlotTransactionError,
94 DataError,
95 ExecAbortError,
96 InvalidPipelineStack,
97 MaxConnectionsError,
98 MovedError,
99 RedisClusterException,
100 RedisError,
101 ResponseError,
102 SlotNotCoveredError,
103 TimeoutError,
104 TryAgainError,
105 WatchError,
106)
107from redis.typing import AnyKeyT, EncodableT, KeyT
108from redis.utils import (
109 DEFAULT_RESP_VERSION,
110 SSL_AVAILABLE,
111 check_protocol_version,
112 deprecated_args,
113 deprecated_function,
114 safe_str,
115 str_if_bytes,
116 truncate_text,
117)
118
119if SSL_AVAILABLE:
120 from ssl import TLSVersion, VerifyFlags, VerifyMode
121else:
122 TLSVersion = None
123 VerifyMode = None
124 VerifyFlags = None
125
logger: logging.Logger = logging.getLogger(__name__)

# Shapes accepted for a ``target_nodes`` argument: a node-flag name, a single
# node, a list of nodes, or a mapping whose values are nodes.
TargetNodesT = TypeVar(
    "TargetNodesT",
    str,
    "ClusterNode",
    List["ClusterNode"],
    Dict[Any, "ClusterNode"],
)
131
132
133class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
134 """
135 Create a new RedisCluster client.
136
137 Pass one of parameters:
138
139 - `host` & `port`
140 - `startup_nodes`
141
142 | Use ``await`` :meth:`initialize` to find cluster nodes & create connections.
143 | Use ``await`` :meth:`close` to disconnect connections & close client.
144
145 Many commands support the target_nodes kwarg. It can be one of the
146 :attr:`NODE_FLAGS`:
147
148 - :attr:`PRIMARIES`
149 - :attr:`REPLICAS`
150 - :attr:`ALL_NODES`
151 - :attr:`RANDOM`
152 - :attr:`DEFAULT_NODE`
153
154 Note: This client is not thread/process/fork safe.
155
156 :param host:
157 | Can be used to point to a startup node
158 :param port:
159 | Port used if **host** is provided
160 :param startup_nodes:
        | :class:`~.ClusterNode` instances to be used as startup nodes
162 :param require_full_coverage:
163 | When set to ``False``: the client will not require a full coverage of
164 the slots. However, if not all slots are covered, and at least one node
165 has ``cluster-require-full-coverage`` set to ``yes``, the server will throw
166 a :class:`~.ClusterDownError` for some key-based commands.
167 | When set to ``True``: all slots must be covered to construct the cluster
168 client. If not all slots are covered, :class:`~.RedisClusterException` will be
169 thrown.
170 | See:
171 https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
172 :param read_from_replicas:
173 | @deprecated - please use load_balancing_strategy instead
174 | Enable read from replicas in READONLY mode.
          When set to true, read commands will be assigned between the primary and
          its replicas in a Round-Robin manner.
177 The data read from replicas is eventually consistent with the data in primary nodes.
178 :param load_balancing_strategy:
179 | Enable read from replicas in READONLY mode and defines the load balancing
180 strategy that will be used for cluster node selection.
181 The data read from replicas is eventually consistent with the data in primary nodes.
182 :param dynamic_startup_nodes:
183 | Set the RedisCluster's startup nodes to all the discovered nodes.
184 If true (default value), the cluster's discovered nodes will be used to
185 determine the cluster nodes-slots mapping in the next topology refresh.
186 It will remove the initial passed startup nodes if their endpoints aren't
187 listed in the CLUSTER SLOTS output.
188 If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
189 specific IP addresses, it is best to set it to false.
190 :param reinitialize_steps:
191 | Specifies the number of MOVED errors that need to occur before reinitializing
192 the whole cluster topology. If a MOVED error occurs and the cluster does not
193 need to be reinitialized on this current error handling, only the MOVED slot
194 will be patched with the redirected node.
195 To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
196 To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
197 0.
198 :param cluster_error_retry_attempts:
199 | @deprecated - Please configure the 'retry' object instead
200 In case 'retry' object is set - this argument is ignored!
201
202 Number of times to retry before raising an error when :class:`~.TimeoutError`,
203 :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError`
204 or :class:`~.ClusterDownError` are encountered
205 :param retry:
206 | A retry object that defines the retry strategy and the number of
207 retries for the cluster client.
          In the current implementation for the cluster client (starting from redis-py version 6.0.0)
          the retry object is not yet fully utilized; instead it is used just to determine
          the number of retries for the cluster client.
          In future releases the retry object will be used to handle the cluster client retries!
212 :param max_connections:
        | Maximum number of connections per node. If there are no free connections and
          the maximum number of connections has already been created, a
          :class:`~.MaxConnectionsError` is raised.
216 :param address_remap:
217 | An optional callable which, when provided with an internal network
218 address of a node, e.g. a `(host, port)` tuple, will return the address
219 where the node is reachable. This can be used to map the addresses at
220 which the nodes _think_ they are, to addresses at which a client may
221 reach them, such as when they sit behind a proxy.
222
223 | Rest of the arguments will be passed to the
224 :class:`~redis.asyncio.connection.Connection` instances when created
225
226 :raises RedisClusterException:
227 if any arguments are invalid or unknown. Eg:
228
229 - `db` != 0 or None
230 - `path` argument for unix socket connection
231 - none of the `host`/`port` & `startup_nodes` were provided
232
233 """
234
235 @classmethod
236 def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
237 """
238 Return a Redis client object configured from the given URL.
239
240 For example::
241
242 redis://[[username]:[password]]@localhost:6379/0
243 rediss://[[username]:[password]]@localhost:6379/0
244
245 Three URL schemes are supported:
246
247 - `redis://` creates a TCP socket connection. See more at:
248 <https://www.iana.org/assignments/uri-schemes/prov/redis>
249 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
250 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
251
252 The username, password, hostname, path and all querystring values are passed
253 through ``urllib.parse.unquote`` in order to replace any percent-encoded values
254 with their corresponding characters.
255
256 All querystring options are cast to their appropriate Python types. Boolean
257 arguments can be specified with string values "True"/"False" or "Yes"/"No".
258 Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
259 parsed, the querystring arguments and keyword arguments are passed to
260 :class:`~redis.asyncio.connection.Connection` when created.
261 In the case of conflicting arguments, querystring arguments are used.
262 """
263 kwargs.update(parse_url(url))
264 if kwargs.pop("connection_class", None) is SSLConnection:
265 kwargs["ssl"] = True
266 return cls(**kwargs)
267
    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    # Attribute names reserved as slots on RedisCluster instances.
    # NOTE(review): __init__ also assigns attributes that are not listed here
    # (e.g. startup_nodes, load_balancing_strategy, _usage_counter,
    # _policy_resolver), so instances presumably still get a __dict__ from a
    # base class -- confirm before relying on __slots__ for memory savings.
    __slots__ = (
        "_initialize",
        "_lock",
        "retry",
        "command_flags",
        "commands_parser",
        "connection_kwargs",
        "encoder",
        "node_flags",
        "nodes_manager",
        "read_from_replicas",
        "reinitialize_counter",
        "reinitialize_steps",
        "response_callbacks",
        "result_callbacks",
    )
287
288 @deprecated_args(
289 args_to_warn=["read_from_replicas"],
290 reason="Please configure the 'load_balancing_strategy' instead",
291 version="5.3.0",
292 )
293 @deprecated_args(
294 args_to_warn=[
295 "cluster_error_retry_attempts",
296 ],
297 reason="Please configure the 'retry' object instead",
298 version="6.0.0",
299 )
300 @deprecated_args(
301 args_to_warn=["lib_name", "lib_version"],
302 reason="Use 'driver_info' parameter instead. "
303 "lib_name and lib_version will be removed in a future version.",
304 )
305 def __init__(
306 self,
307 host: Optional[str] = None,
308 port: Union[str, int] = 6379,
309 # Cluster related kwargs
310 startup_nodes: Optional[List["ClusterNode"]] = None,
311 require_full_coverage: bool = True,
312 read_from_replicas: bool = False,
313 load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
314 dynamic_startup_nodes: bool = True,
315 reinitialize_steps: int = 5,
316 cluster_error_retry_attempts: int = 3,
317 max_connections: int = 2**31,
318 retry: Optional["Retry"] = None,
319 retry_on_error: Optional[List[Type[Exception]]] = None,
320 # Client related kwargs
321 db: Union[str, int] = 0,
322 path: Optional[str] = None,
323 credential_provider: Optional[CredentialProvider] = None,
324 username: Optional[str] = None,
325 password: Optional[str] = None,
326 client_name: Optional[str] = None,
327 lib_name: Optional[str] = None,
328 lib_version: Optional[str] = None,
329 driver_info: Optional["DriverInfo"] = None,
330 # Encoding related kwargs
331 encoding: str = "utf-8",
332 encoding_errors: str = "strict",
333 decode_responses: bool = False,
334 # Connection related kwargs
335 health_check_interval: float = 0,
336 socket_connect_timeout: Optional[float] = None,
337 socket_keepalive: bool = False,
338 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
339 socket_timeout: Optional[float] = None,
340 # SSL related kwargs
341 ssl: bool = False,
342 ssl_ca_certs: Optional[str] = None,
343 ssl_ca_data: Optional[str] = None,
344 ssl_cert_reqs: Union[str, VerifyMode] = "required",
345 ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
346 ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
347 ssl_certfile: Optional[str] = None,
348 ssl_check_hostname: bool = True,
349 ssl_keyfile: Optional[str] = None,
350 ssl_min_version: Optional[TLSVersion] = None,
351 ssl_ciphers: Optional[str] = None,
352 protocol: Optional[int] = 3,
353 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
354 event_dispatcher: Optional[EventDispatcher] = None,
355 policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
356 ) -> None:
357 if db:
358 raise RedisClusterException(
359 "Argument 'db' must be 0 or None in cluster mode"
360 )
361
362 if path:
363 raise RedisClusterException(
364 "Unix domain socket is not supported in cluster mode"
365 )
366
367 if (not host or not port) and not startup_nodes:
368 raise RedisClusterException(
369 "RedisCluster requires at least one node to discover the cluster.\n"
370 "Please provide one of the following or use RedisCluster.from_url:\n"
371 ' - host and port: RedisCluster(host="localhost", port=6379)\n'
372 " - startup_nodes: RedisCluster(startup_nodes=["
373 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
374 )
375
376 computed_driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
377
378 kwargs: Dict[str, Any] = {
379 "max_connections": max_connections,
380 "connection_class": Connection,
381 # Client related kwargs
382 "credential_provider": credential_provider,
383 "username": username,
384 "password": password,
385 "client_name": client_name,
386 "driver_info": computed_driver_info,
387 # Encoding related kwargs
388 "encoding": encoding,
389 "encoding_errors": encoding_errors,
390 "decode_responses": decode_responses,
391 # Connection related kwargs
392 "health_check_interval": health_check_interval,
393 "socket_connect_timeout": socket_connect_timeout,
394 "socket_keepalive": socket_keepalive,
395 "socket_keepalive_options": socket_keepalive_options,
396 "socket_timeout": socket_timeout,
397 "protocol": protocol,
398 }
399
400 if ssl:
401 # SSL related kwargs
402 kwargs.update(
403 {
404 "connection_class": SSLConnection,
405 "ssl_ca_certs": ssl_ca_certs,
406 "ssl_ca_data": ssl_ca_data,
407 "ssl_cert_reqs": ssl_cert_reqs,
408 "ssl_include_verify_flags": ssl_include_verify_flags,
409 "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
410 "ssl_certfile": ssl_certfile,
411 "ssl_check_hostname": ssl_check_hostname,
412 "ssl_keyfile": ssl_keyfile,
413 "ssl_min_version": ssl_min_version,
414 "ssl_ciphers": ssl_ciphers,
415 }
416 )
417
418 if read_from_replicas or load_balancing_strategy:
419 # Call our on_connect function to configure READONLY mode
420 kwargs["redis_connect_func"] = self.on_connect
421
422 if retry:
423 self.retry = retry
424 else:
425 self.retry = Retry(
426 backoff=ExponentialWithJitterBackoff(base=1, cap=10),
427 retries=cluster_error_retry_attempts,
428 )
429 if retry_on_error:
430 self.retry.update_supported_errors(retry_on_error)
431
432 kwargs["response_callbacks"] = _RedisCallbacks.copy()
433 if check_protocol_version(kwargs.get("protocol", DEFAULT_RESP_VERSION), 3):
434 kwargs["response_callbacks"].update(_RedisCallbacksRESP3)
435 else:
436 kwargs["response_callbacks"].update(_RedisCallbacksRESP2)
437 self.connection_kwargs = kwargs
438
439 if startup_nodes:
440 passed_nodes = []
441 for node in startup_nodes:
442 passed_nodes.append(
443 ClusterNode(node.host, node.port, **self.connection_kwargs)
444 )
445 startup_nodes = passed_nodes
446 else:
447 startup_nodes = []
448 if host and port:
449 startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
450
451 if event_dispatcher is None:
452 self._event_dispatcher = EventDispatcher()
453 else:
454 self._event_dispatcher = event_dispatcher
455
456 self.startup_nodes = startup_nodes
457 self.nodes_manager = NodesManager(
458 startup_nodes,
459 require_full_coverage,
460 kwargs,
461 dynamic_startup_nodes=dynamic_startup_nodes,
462 address_remap=address_remap,
463 event_dispatcher=self._event_dispatcher,
464 )
465 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
466 self.read_from_replicas = read_from_replicas
467 self.load_balancing_strategy = load_balancing_strategy
468 self.reinitialize_steps = reinitialize_steps
469 self.reinitialize_counter = 0
470
471 # For backward compatibility, mapping from existing policies to new one
472 self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
473 self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
474 self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
475 self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
476 self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
477 self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
478 SLOT_ID: RequestPolicy.DEFAULT_KEYED,
479 }
480
481 self._policies_callback_mapping: dict[
482 Union[RequestPolicy, ResponsePolicy], Callable
483 ] = {
484 RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
485 self.get_random_primary_or_all_nodes(command_name)
486 ],
487 RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
488 RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
489 RequestPolicy.ALL_SHARDS: self.get_primaries,
490 RequestPolicy.ALL_NODES: self.get_nodes,
491 RequestPolicy.ALL_REPLICAS: self.get_replicas,
492 RequestPolicy.SPECIAL: self.get_special_nodes,
493 ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
494 ResponsePolicy.DEFAULT_KEYED: lambda res: res,
495 }
496
497 self._policy_resolver = policy_resolver
498 self.commands_parser = AsyncCommandsParser()
499 self._aggregate_nodes = None
500 self.node_flags = self.__class__.NODE_FLAGS.copy()
501 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
502 self.response_callbacks = kwargs["response_callbacks"]
503 self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
504 self.result_callbacks["CLUSTER SLOTS"] = (
505 lambda cmd, res, **kwargs: parse_cluster_slots(
506 list(res.values())[0], **kwargs
507 )
508 )
509
510 self._initialize = True
511 self._lock: Optional[asyncio.Lock] = None
512
513 # When used as an async context manager, we need to increment and decrement
514 # a usage counter so that we can close the connection pool when no one is
515 # using the client.
516 self._usage_counter = 0
517 self._usage_lock = asyncio.Lock()
518
519 async def initialize(self) -> "RedisCluster":
520 """Get all nodes from startup nodes & creates connections if not initialized."""
521 if self._initialize:
522 if not self._lock:
523 self._lock = asyncio.Lock()
524 async with self._lock:
525 if self._initialize:
526 try:
527 await self.nodes_manager.initialize()
528 await self.commands_parser.initialize(
529 self.nodes_manager.default_node
530 )
531 self._initialize = False
532 except BaseException:
533 await self.nodes_manager.aclose()
534 await self.nodes_manager.aclose("startup_nodes")
535 raise
536 return self
537
538 async def aclose(self) -> None:
539 """Close all connections & client if initialized."""
540 if not self._initialize:
541 if not self._lock:
542 self._lock = asyncio.Lock()
543 async with self._lock:
544 if not self._initialize:
545 self._initialize = True
546 await self.nodes_manager.aclose()
547 await self.nodes_manager.aclose("startup_nodes")
548
549 @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
550 async def close(self) -> None:
551 """alias for aclose() for backwards compatibility"""
552 await self.aclose()
553
554 async def __aenter__(self) -> "RedisCluster":
555 """
556 Async context manager entry. Increments a usage counter so that the
557 connection pool is only closed (via aclose()) when no context is using
558 the client.
559 """
560 await self._increment_usage()
561 try:
562 # Initialize the client (i.e. establish connection, etc.)
563 return await self.initialize()
564 except Exception:
565 # If initialization fails, decrement the counter to keep it in sync
566 await self._decrement_usage()
567 raise
568
569 async def _increment_usage(self) -> int:
570 """
571 Helper coroutine to increment the usage counter while holding the lock.
572 Returns the new value of the usage counter.
573 """
574 async with self._usage_lock:
575 self._usage_counter += 1
576 return self._usage_counter
577
578 async def _decrement_usage(self) -> int:
579 """
580 Helper coroutine to decrement the usage counter while holding the lock.
581 Returns the new value of the usage counter.
582 """
583 async with self._usage_lock:
584 self._usage_counter -= 1
585 return self._usage_counter
586
587 async def __aexit__(self, exc_type, exc_value, traceback):
588 """
589 Async context manager exit. Decrements a usage counter. If this is the
590 last exit (counter becomes zero), the client closes its connection pool.
591 """
592 current_usage = await asyncio.shield(self._decrement_usage())
593 if current_usage == 0:
594 # This was the last active context, so disconnect the pool.
595 await asyncio.shield(self.aclose())
596
597 def __await__(self) -> Generator[Any, None, "RedisCluster"]:
598 return self.initialize().__await__()
599
600 _DEL_MESSAGE = "Unclosed RedisCluster client"
601
602 def __del__(
603 self,
604 _warn: Any = warnings.warn,
605 _grl: Any = asyncio.get_running_loop,
606 ) -> None:
607 if hasattr(self, "_initialize") and not self._initialize:
608 _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
609 try:
610 context = {"client": self, "message": self._DEL_MESSAGE}
611 _grl().call_exception_handler(context)
612 except RuntimeError:
613 pass
614
615 async def on_connect(self, connection: Connection) -> None:
616 await connection.on_connect()
617
618 # Sending READONLY command to server to configure connection as
619 # readonly. Since each cluster node may change its server type due
620 # to a failover, we should establish a READONLY connection
621 # regardless of the server type. If this is a primary connection,
622 # READONLY would not affect executing write commands.
623 await connection.send_command("READONLY")
624 if str_if_bytes(await connection.read_response()) != "OK":
625 raise ConnectionError("READONLY command failed")
626
627 def get_nodes(self) -> List["ClusterNode"]:
628 """Get all nodes of the cluster."""
629 return list(self.nodes_manager.nodes_cache.values())
630
631 def get_primaries(self) -> List["ClusterNode"]:
632 """Get the primary nodes of the cluster."""
633 return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
634
635 def get_replicas(self) -> List["ClusterNode"]:
636 """Get the replica nodes of the cluster."""
637 return self.nodes_manager.get_nodes_by_server_type(REPLICA)
638
639 def get_random_node(self) -> "ClusterNode":
640 """Get a random node of the cluster."""
641 return random.choice(list(self.nodes_manager.nodes_cache.values()))
642
643 def get_default_node(self) -> "ClusterNode":
644 """Get the default node of the client."""
645 return self.nodes_manager.default_node
646
647 def set_default_node(self, node: "ClusterNode") -> None:
648 """
649 Set the default node of the client.
650
651 :raises DataError: if None is passed or node does not exist in cluster.
652 """
653 if not node or not self.get_node(node_name=node.name):
654 raise DataError("The requested node does not exist in the cluster.")
655
656 self.nodes_manager.default_node = node
657
658 def get_node(
659 self,
660 host: Optional[str] = None,
661 port: Optional[int] = None,
662 node_name: Optional[str] = None,
663 ) -> Optional["ClusterNode"]:
664 """Get node by (host, port) or node_name."""
665 return self.nodes_manager.get_node(host, port, node_name)
666
667 def get_node_from_key(
668 self, key: str, replica: bool = False
669 ) -> Optional["ClusterNode"]:
670 """
671 Get the cluster node corresponding to the provided key.
672
673 :param key:
674 :param replica:
675 | Indicates if a replica should be returned
676 |
677 None will returned if no replica holds this key
678
679 :raises SlotNotCoveredError: if the key is not covered by any slot.
680 """
681 slot = self.keyslot(key)
682 slot_cache = self.nodes_manager.slots_cache.get(slot)
683 if not slot_cache:
684 raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
685
686 if replica:
687 if len(self.nodes_manager.slots_cache[slot]) < 2:
688 return None
689 node_idx = 1
690 else:
691 node_idx = 0
692
693 return slot_cache[node_idx]
694
695 def get_random_primary_or_all_nodes(self, command_name):
696 """
697 Returns random primary or all nodes depends on READONLY mode.
698 """
699 if self.read_from_replicas and command_name in READ_COMMANDS:
700 return self.get_random_node()
701
702 return self.get_random_primary_node()
703
704 def get_random_primary_node(self) -> "ClusterNode":
705 """
706 Returns a random primary node
707 """
708 return random.choice(self.get_primaries())
709
710 async def get_nodes_from_slot(self, command: str, *args):
711 """
712 Returns a list of nodes that hold the specified keys' slots.
713 """
714 # get the node that holds the key's slot
715 return [
716 self.nodes_manager.get_node_from_slot(
717 await self._determine_slot(command, *args),
718 self.read_from_replicas and command in READ_COMMANDS,
719 self.load_balancing_strategy if command in READ_COMMANDS else None,
720 )
721 ]
722
723 def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
724 """
725 Returns a list of nodes for commands with a special policy.
726 """
727 if not self._aggregate_nodes:
728 raise RedisClusterException(
729 "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
730 )
731
732 return self._aggregate_nodes
733
734 def keyslot(self, key: EncodableT) -> int:
735 """
736 Find the keyslot for a given key.
737
738 See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding
739 """
740 return key_slot(self.encoder.encode(key))
741
742 def get_encoder(self) -> Encoder:
743 """Get the encoder object of the client."""
744 return self.encoder
745
746 def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
747 """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`."""
748 return self.connection_kwargs
749
750 def set_retry(self, retry: Retry) -> None:
751 self.retry = retry
752
753 def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
754 """Set a custom response callback."""
755 self.response_callbacks[command] = callback
756
757 async def _determine_nodes(
758 self,
759 command: str,
760 *args: Any,
761 request_policy: RequestPolicy,
762 node_flag: Optional[str] = None,
763 ) -> List["ClusterNode"]:
764 # Determine which nodes should be executed the command on.
765 # Returns a list of target nodes.
766 if not node_flag:
767 # get the nodes group for this command if it was predefined
768 node_flag = self.command_flags.get(command)
769
770 if node_flag in self._command_flags_mapping:
771 request_policy = self._command_flags_mapping[node_flag]
772
773 policy_callback = self._policies_callback_mapping[request_policy]
774
775 if request_policy == RequestPolicy.DEFAULT_KEYED:
776 nodes = await policy_callback(command, *args)
777 elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
778 nodes = policy_callback(command)
779 else:
780 nodes = policy_callback()
781
782 if command.lower() == "ft.aggregate":
783 self._aggregate_nodes = nodes
784
785 return nodes
786
787 async def _determine_slot(self, command: str, *args: Any) -> int:
788 if self.command_flags.get(command) == SLOT_ID:
789 # The command contains the slot ID
790 return int(args[0])
791
792 # Get the keys in the command
793
794 # EVAL and EVALSHA are common enough that it's wasteful to go to the
795 # redis server to parse the keys. Besides, there is a bug in redis<7.0
796 # where `self._get_command_keys()` fails anyway. So, we special case
797 # EVAL/EVALSHA.
798 # - issue: https://github.com/redis/redis/issues/9493
799 # - fix: https://github.com/redis/redis/pull/9733
800 if command.upper() in ("EVAL", "EVALSHA"):
801 # command syntax: EVAL "script body" num_keys ...
802 if len(args) < 2:
803 raise RedisClusterException(
804 f"Invalid args in command: {command, *args}"
805 )
806 keys = args[2 : 2 + int(args[1])]
807 # if there are 0 keys, that means the script can be run on any node
808 # so we can just return a random slot
809 if not keys:
810 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
811 else:
812 keys = await self.commands_parser.get_keys(command, *args)
813 if not keys:
814 # FCALL can call a function with 0 keys, that means the function
815 # can be run on any node so we can just return a random slot
816 if command.upper() in ("FCALL", "FCALL_RO"):
817 return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
818 raise RedisClusterException(
819 "No way to dispatch this command to Redis Cluster. "
820 "Missing key.\nYou can execute the command by specifying "
821 f"target nodes.\nCommand: {args}"
822 )
823
824 # single key command
825 if len(keys) == 1:
826 return self.keyslot(keys[0])
827
828 # multi-key command; we need to make sure all keys are mapped to
829 # the same slot
830 slots = {self.keyslot(key) for key in keys}
831 if len(slots) != 1:
832 raise RedisClusterException(
833 f"{command} - all keys must map to the same key slot"
834 )
835
836 return slots.pop()
837
838 def _is_node_flag(self, target_nodes: Any) -> bool:
839 return isinstance(target_nodes, str) and target_nodes in self.node_flags
840
841 def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
842 if isinstance(target_nodes, list):
843 nodes = target_nodes
844 elif isinstance(target_nodes, ClusterNode):
845 # Supports passing a single ClusterNode as a variable
846 nodes = [target_nodes]
847 elif isinstance(target_nodes, dict):
848 # Supports dictionaries of the format {node_name: node}.
849 # It enables to execute commands with multi nodes as follows:
850 # rc.cluster_save_config(rc.get_primaries())
851 nodes = list(target_nodes.values())
852 else:
853 raise TypeError(
854 "target_nodes type can be one of the following: "
855 "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES),"
856 "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. "
857 f"The passed type is {type(target_nodes)}"
858 )
859 return nodes
860
861 async def _record_error_metric(
862 self,
863 error: Exception,
864 connection: Union[Connection, "ClusterNode"],
865 is_internal: bool = True,
866 retry_attempts: Optional[int] = None,
867 ):
868 """
869 Records error count metric directly.
870 Accepts either a Connection or ClusterNode object.
871 """
872 await record_error_count(
873 server_address=connection.host,
874 server_port=connection.port,
875 network_peer_address=connection.host,
876 network_peer_port=connection.port,
877 error_type=error,
878 retry_attempts=retry_attempts if retry_attempts is not None else 0,
879 is_internal=is_internal,
880 )
881
882 async def _record_command_metric(
883 self,
884 command_name: str,
885 duration_seconds: float,
886 connection: Union[Connection, "ClusterNode"],
887 error: Optional[Exception] = None,
888 ):
889 """
890 Records operation duration metric directly.
891 Accepts either a Connection or ClusterNode object.
892 """
893 # Connection has db attribute, ClusterNode has connection_kwargs
894 if hasattr(connection, "db"):
895 db = connection.db
896 else:
897 db = connection.connection_kwargs.get("db", 0)
898 await record_operation_duration(
899 command_name=command_name,
900 duration_seconds=duration_seconds,
901 server_address=connection.host,
902 server_port=connection.port,
903 db_namespace=str(db) if db is not None else None,
904 error=error,
905 )
906
    async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
        """
        Execute a raw command on the appropriate cluster node or target_nodes.

        It will retry the command as specified by the retries property of
        the :attr:`retry` & then raise an exception. Only exceptions whose
        exact type is listed in ``ERRORS_ALLOW_RETRY`` are retried, and
        retries are disabled entirely when explicit target nodes (anything
        other than a node flag) are passed.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection

        :return: For a single target node, its reply (first passed through the
            command's result callback, if one is registered) processed by the
            response-policy callback. For several nodes, a ``{node name: reply}``
            dict handed to the result callback if one is registered, otherwise
            to the response-policy callback.

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot, or if no target nodes could be determined
        """
        command = args[0]
        target_nodes = []
        target_nodes_specified = False
        retry_attempts = self.retry.get_retries()

        # Explicit nodes (as opposed to a node flag) are used as-is and the
        # command is attempted exactly once.
        passed_targets = kwargs.pop("target_nodes", None)
        if passed_targets and not self._is_node_flag(passed_targets):
            target_nodes = self._parse_target_nodes(passed_targets)
            target_nodes_specified = True
            retry_attempts = 0

        command_policies = await self._policy_resolver.resolve(args[0].lower())

        # No policy is known for this command: derive one from the command
        # flags, or from whether the command maps to a key slot.
        if not command_policies and not target_nodes_specified:
            command_flag = self.command_flags.get(command)
            if not command_flag:
                # Fallback to default policy
                if not self.get_default_node():
                    slot = None
                else:
                    slot = await self._determine_slot(*args)
                if slot is None:
                    command_policies = CommandPolicies()
                else:
                    command_policies = CommandPolicies(
                        request_policy=RequestPolicy.DEFAULT_KEYED,
                        response_policy=ResponsePolicy.DEFAULT_KEYED,
                    )
            else:
                if command_flag in self._command_flags_mapping:
                    command_policies = CommandPolicies(
                        request_policy=self._command_flags_mapping[command_flag]
                    )
                else:
                    command_policies = CommandPolicies()
        elif not command_policies and target_nodes_specified:
            command_policies = CommandPolicies()

        # Add one for the first execution
        execute_attempts = 1 + retry_attempts
        failure_count = 0

        # Start timing for observability
        start_time = time.monotonic()

        for _ in range(execute_attempts):
            # ``_initialize`` can be set by _execute_command after a connection
            # error; refresh the topology before the next attempt. Here
            # ``target_nodes`` still holds the previous attempt's nodes.
            if self._initialize:
                await self.initialize()
                if (
                    len(target_nodes) == 1
                    and target_nodes[0] == self.get_default_node()
                ):
                    # Replace the default cluster node
                    self.replace_default_node()
            try:
                if not target_nodes_specified:
                    # Determine the nodes to execute the command on
                    target_nodes = await self._determine_nodes(
                        *args,
                        request_policy=command_policies.request_policy,
                        node_flag=passed_targets,
                    )
                    if not target_nodes:
                        raise RedisClusterException(
                            f"No targets were found to execute {args} command on"
                        )

                if len(target_nodes) == 1:
                    # Return the processed result
                    ret = await self._execute_command(target_nodes[0], *args, **kwargs)
                    if command in self.result_callbacks:
                        ret = self.result_callbacks[command](
                            command, {target_nodes[0].name: ret}, **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](ret)
                else:
                    # Fan out to all target nodes concurrently and combine the
                    # replies keyed by node name.
                    keys = [node.name for node in target_nodes]
                    values = await asyncio.gather(
                        *(
                            asyncio.create_task(
                                self._execute_command(node, *args, **kwargs)
                            )
                            for node in target_nodes
                        )
                    )
                    if command in self.result_callbacks:
                        return self.result_callbacks[command](
                            command, dict(zip(keys, values)), **kwargs
                        )
                    return self._policies_callback_mapping[
                        command_policies.response_policy
                    ](dict(zip(keys, values)))
            except Exception as e:
                if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
                    # The nodes and slots cache may have been flagged for
                    # reinitialization; the next iteration refreshes them if so.
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    failure_count += 1

                    # NOTE(review): _execute_command already records a command
                    # metric for the failed attempt; this records another one
                    # measured from the start of execute_command. Confirm the
                    # double recording is intended.
                    if hasattr(e, "connection"):
                        await self._record_command_metric(
                            command_name=command,
                            duration_seconds=time.monotonic() - start_time,
                            connection=e.connection,
                            error=e,
                        )
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                        )
                    continue
                else:
                    # raise the exception
                    if hasattr(e, "connection"):
                        await self._record_error_metric(
                            error=e,
                            connection=e.connection,
                            retry_attempts=failure_count,
                            is_internal=False,
                        )
                    raise e
1049
    async def _execute_command(
        self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        """
        Run one command against ``target_node``, following cluster redirects.

        At most ``RedisClusterRequestTTL`` attempts are made:

        - MOVED patches the slots cache via ``nodes_manager.move_slot`` (or
          closes the client every ``reinitialize_steps`` MOVEDs). The node for
          the command's slot is then re-resolved.
        - ASK re-sends the command to the redirect target, preceded by
          ``ASKING``.
        - TRYAGAIN is retried, with a short sleep once more than half of the
          TTL has been used.

        Every other error is re-raised with a ``connection`` attribute set to
        the node involved. Each attempt records a command-duration metric.

        :raises ClusterError: "TTL exhausted." when all attempts are used up.
        """
        asking = moved = False
        redirect_addr = None
        ttl = self.RedisClusterRequestTTL
        command = args[0]
        # NOTE(review): taken once, so every recorded duration is cumulative
        # since the first attempt (redirects and sleeps included). Confirm
        # this is the intended metric.
        start_time = time.monotonic()

        while ttl > 0:
            ttl -= 1
            try:
                if asking:
                    # NOTE(review): get_node() returns None for an ASK target
                    # that is not in nodes_cache. The call below would then
                    # raise AttributeError, and the generic handler would pass
                    # ``None`` to _record_command_metric. Confirm this cannot
                    # happen.
                    target_node = self.get_node(node_name=redirect_addr)
                    await target_node.execute_command("ASKING")
                    asking = False
                elif moved:
                    # MOVED occurred and the slots cache was updated,
                    # refresh the target node
                    slot = await self._determine_slot(*args)
                    target_node = self.nodes_manager.get_node_from_slot(
                        slot,
                        self.read_from_replicas and args[0] in READ_COMMANDS,
                        self.load_balancing_strategy
                        if args[0] in READ_COMMANDS
                        else None,
                    )
                    moved = False

                response = await target_node.execute_command(*args, **kwargs)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                )
                return response
            except BusyLoadingError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except MaxConnectionsError as e:
                # MaxConnectionsError indicates client-side resource exhaustion
                # (too many connections in the pool), not a node failure.
                # Don't treat this as a node failure - just re-raise the error
                # without reinitializing the cluster.
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ConnectionError, TimeoutError) as e:
                # Connection retries are being handled in the node's
                # Retry object.
                # Mark active connections for reconnect and disconnect free ones
                # This handles connection state (like READONLY) that may be stale
                target_node.update_active_connections_for_reconnect()
                await target_node.disconnect_free_connections()

                # Move the failed node to the end of the cached nodes list
                # so it's tried last during reinitialization
                self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

                # Signal that reinitialization is needed
                # The retry loop will handle initialize() AND replace_default_node()
                self._initialize = True
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except (ClusterDownError, SlotNotCoveredError) as e:
                # ClusterDownError can occur during a failover and to get
                # self-healed, we will try to reinitialize the cluster layout
                # and retry executing the command

                # SlotNotCoveredError can occur when the cluster is not fully
                # initialized or can be temporary issue.
                # We will try to reinitialize the cluster topology
                # and retry executing the command

                await self.aclose()
                await asyncio.sleep(0.25)
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            # NOTE(review): the redirect handlers below must stay above the
            # ResponseError handler; presumably MovedError/AskError/TryAgainError
            # subclass ResponseError. Confirm.
            except MovedError as e:
                # First, we will try to patch the slots/nodes cache with the
                # redirected node output and try again. If MovedError exceeds
                # 'reinitialize_steps' number of times, we will force
                # reinitializing the tables, and then try again.
                # 'reinitialize_steps' counter will increase faster when
                # the same client object is shared between multiple threads. To
                # reduce the frequency you can set this variable in the
                # RedisCluster constructor.
                self.reinitialize_counter += 1
                if (
                    self.reinitialize_steps
                    and self.reinitialize_counter % self.reinitialize_steps == 0
                ):
                    await self.aclose()
                    # Reset the counter
                    self.reinitialize_counter = 0
                else:
                    await self.nodes_manager.move_slot(e)
                moved = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except AskError as e:
                # Next iteration sends ASKING + the command to the redirect target.
                redirect_addr = get_node_name(host=e.host, port=e.port)
                asking = True
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except TryAgainError as e:
                # Back off briefly only once more than half the TTL is used.
                if ttl < self.RedisClusterRequestTTL / 2:
                    await asyncio.sleep(0.05)
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                await self._record_error_metric(
                    error=e,
                    connection=target_node,
                )
            except ResponseError as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise
            except Exception as e:
                e.connection = target_node
                await self._record_command_metric(
                    command_name=command,
                    duration_seconds=time.monotonic() - start_time,
                    connection=target_node,
                    error=e,
                )
                raise

        # Only redirect/TRYAGAIN handlers loop; reaching here means they
        # consumed every attempt.
        e = ClusterError("TTL exhausted.")
        e.connection = target_node
        await self._record_command_metric(
            command_name=command,
            duration_seconds=time.monotonic() - start_time,
            connection=target_node,
            error=e,
        )
        raise e
1235
1236 def pipeline(
1237 self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
1238 ) -> "ClusterPipeline":
1239 """
1240 Create & return a new :class:`~.ClusterPipeline` object.
1241
1242 Cluster implementation of pipeline does not support transaction or shard_hint.
1243
1244 :raises RedisClusterException: if transaction or shard_hint are truthy values
1245 """
1246 if shard_hint:
1247 raise RedisClusterException("shard_hint is deprecated in cluster mode")
1248
1249 return ClusterPipeline(self, transaction)
1250
1251 def pubsub(
1252 self,
1253 node: Optional["ClusterNode"] = None,
1254 host: Optional[str] = None,
1255 port: Optional[int] = None,
1256 **kwargs: Any,
1257 ) -> "ClusterPubSub":
1258 """
1259 Create and return a ClusterPubSub instance.
1260
1261 Allows passing a ClusterNode, or host&port, to get a pubsub instance
1262 connected to the specified node
1263
1264 :param node: ClusterNode to connect to
1265 :param host: Host of the node to connect to
1266 :param port: Port of the node to connect to
1267 :param kwargs: Additional keyword arguments
1268 :return: ClusterPubSub instance
1269 """
1270 return ClusterPubSub(self, node=node, host=host, port=port, **kwargs)
1271
1272 def keyspace_notifications(
1273 self,
1274 key_prefix: Union[str, bytes, None] = None,
1275 ignore_subscribe_messages: bool = True,
1276 ) -> "AsyncClusterKeyspaceNotifications":
1277 """
1278 Return an
1279 :class:`~redis.asyncio.keyspace_notifications.AsyncClusterKeyspaceNotifications`
1280 object for subscribing to keyspace and keyevent notifications across
1281 all primary nodes in the cluster.
1282
1283 Note: Keyspace notifications must be enabled on all Redis cluster nodes
1284 via the ``notify-keyspace-events`` configuration option.
1285
1286 Args:
1287 key_prefix: Optional prefix to filter and strip from keys in
1288 notifications.
1289 ignore_subscribe_messages: If True, subscribe/unsubscribe
1290 confirmations are not returned by
1291 get_message/listen.
1292 """
1293 from redis.asyncio.keyspace_notifications import (
1294 AsyncClusterKeyspaceNotifications,
1295 )
1296
1297 return AsyncClusterKeyspaceNotifications(
1298 self,
1299 key_prefix=key_prefix,
1300 ignore_subscribe_messages=ignore_subscribe_messages,
1301 )
1302
1303 def lock(
1304 self,
1305 name: KeyT,
1306 timeout: Optional[float] = None,
1307 sleep: float = 0.1,
1308 blocking: bool = True,
1309 blocking_timeout: Optional[float] = None,
1310 lock_class: Optional[Type[Lock]] = None,
1311 thread_local: bool = True,
1312 raise_on_release_error: bool = True,
1313 ) -> Lock:
1314 """
1315 Return a new Lock object using key ``name`` that mimics
1316 the behavior of threading.Lock.
1317
1318 If specified, ``timeout`` indicates a maximum life for the lock.
1319 By default, it will remain locked until release() is called.
1320
1321 ``sleep`` indicates the amount of time to sleep per loop iteration
1322 when the lock is in blocking mode and another client is currently
1323 holding the lock.
1324
1325 ``blocking`` indicates whether calling ``acquire`` should block until
1326 the lock has been acquired or to fail immediately, causing ``acquire``
1327 to return False and the lock not being acquired. Defaults to True.
1328 Note this value can be overridden by passing a ``blocking``
1329 argument to ``acquire``.
1330
1331 ``blocking_timeout`` indicates the maximum amount of time in seconds to
1332 spend trying to acquire the lock. A value of ``None`` indicates
1333 continue trying forever. ``blocking_timeout`` can be specified as a
1334 float or integer, both representing the number of seconds to wait.
1335
1336 ``lock_class`` forces the specified lock implementation. Note that as
1337 of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
1338 a Lua-based lock). So, it's unlikely you'll need this parameter, unless
1339 you have created your own custom lock class.
1340
1341 ``thread_local`` indicates whether the lock token is placed in
1342 thread-local storage. By default, the token is placed in thread local
1343 storage so that a thread only sees its token, not a token set by
1344 another thread. Consider the following timeline:
1345
1346 time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
1347 thread-1 sets the token to "abc"
1348 time: 1, thread-2 blocks trying to acquire `my-lock` using the
1349 Lock instance.
1350 time: 5, thread-1 has not yet completed. redis expires the lock
1351 key.
1352 time: 5, thread-2 acquired `my-lock` now that it's available.
1353 thread-2 sets the token to "xyz"
1354 time: 6, thread-1 finishes its work and calls release(). if the
1355 token is *not* stored in thread local storage, then
1356 thread-1 would see the token value as "xyz" and would be
1357 able to successfully release the thread-2's lock.
1358
1359 ``raise_on_release_error`` indicates whether to raise an exception when
1360 the lock is no longer owned when exiting the context manager. By default,
1361 this is True, meaning an exception will be raised. If False, the warning
1362 will be logged and the exception will be suppressed.
1363
1364 In some use cases it's necessary to disable thread local storage. For
1365 example, if you have code where one thread acquires a lock and passes
1366 that lock instance to a worker thread to release later. If thread
1367 local storage isn't disabled in this case, the worker thread won't see
1368 the token set by the thread that acquired the lock. Our assumption
1369 is that these cases aren't common and as such default to using
1370 thread local storage."""
1371 if lock_class is None:
1372 lock_class = Lock
1373 return lock_class(
1374 self,
1375 name,
1376 timeout=timeout,
1377 sleep=sleep,
1378 blocking=blocking,
1379 blocking_timeout=blocking_timeout,
1380 thread_local=thread_local,
1381 raise_on_release_error=raise_on_release_error,
1382 )
1383
1384 async def transaction(
1385 self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
1386 ):
1387 """
1388 Convenience method for executing the callable `func` as a transaction
1389 while watching all keys specified in `watches`. The 'func' callable
1390 should expect a single argument which is a Pipeline object.
1391 """
1392 shard_hint = kwargs.pop("shard_hint", None)
1393 value_from_callable = kwargs.pop("value_from_callable", False)
1394 watch_delay = kwargs.pop("watch_delay", None)
1395 async with self.pipeline(True, shard_hint) as pipe:
1396 while True:
1397 try:
1398 if watches:
1399 await pipe.watch(*watches)
1400 func_value = await func(pipe)
1401 exec_value = await pipe.execute()
1402 return func_value if value_from_callable else exec_value
1403 except WatchError:
1404 if watch_delay is not None and watch_delay > 0:
1405 time.sleep(watch_delay)
1406 continue
1407
1408
class ClusterNode:
    """
    Create a new ClusterNode.

    Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection`
    objects for the (host, port).

    Connections are created on demand (up to ``max_connections``) and tracked
    in ``_connections``. Idle ones are parked in the ``_free`` deque and
    reused in first-in, first-out order.
    """

    __slots__ = (
        "_connections",
        "_free",
        "_lock",
        "_event_dispatcher",
        "connection_class",
        "connection_kwargs",
        "host",
        "max_connections",
        "name",
        "port",
        "response_callbacks",
        "server_type",
    )

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        server_type: Optional[str] = None,
        *,
        max_connections: int = 2**31,
        connection_class: Type[Connection] = Connection,
        **connection_kwargs: Any,
    ) -> None:
        # Normalise "localhost" to its IP address so the node name is
        # address-based.
        if host == "localhost":
            host = socket.gethostbyname(host)

        connection_kwargs["host"] = host
        connection_kwargs["port"] = port
        self.host = host
        self.port = port
        self.name = get_node_name(host, port)
        self.server_type = server_type

        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # Response callbacks are applied by this node (see parse_response), so
        # they are removed from the kwargs used to build connections.
        self.response_callbacks = connection_kwargs.pop("response_callbacks", {})

        self._connections: List[Connection] = []
        self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        return (
            f"[host={self.host}, port={self.port}, "
            f"name={self.name}, server_type={self.server_type}]"
        )

    def __eq__(self, obj: Any) -> bool:
        # Nodes are identified by name ("host:port"), not by object identity.
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

    _DEL_MESSAGE = "Unclosed ClusterNode object"

    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        """Warn once if this node is collected while a connection is still open."""
        # The helpers are bound as default arguments, presumably so they are
        # still reachable during interpreter shutdown. getattr() protects
        # against a partially-constructed instance (e.g. __init__ raised
        # before ``_connections`` was assigned), whose unset slot would
        # otherwise raise AttributeError here.
        for connection in getattr(self, "_connections", ()):
            if connection.is_connected:
                _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

                try:
                    context = {"client": self, "message": self._DEL_MESSAGE}
                    _grl().call_exception_handler(context)
                except RuntimeError:
                    # No running event loop to report to.
                    pass
                break

    async def disconnect(self) -> None:
        """
        Disconnect every connection created by this node, concurrently.

        All disconnects are awaited. Afterwards, the first exception raised
        by any of them (if any) is re-raised.
        """
        ret = await asyncio.gather(
            *(
                asyncio.create_task(connection.disconnect())
                for connection in self._connections
            ),
            return_exceptions=True,
        )
        exc = next((res for res in ret if isinstance(res, Exception)), None)
        if exc:
            raise exc

    def acquire_connection(self) -> Connection:
        """
        Take an idle connection from the pool, or create a new one.

        :raises MaxConnectionsError: if no connection is idle and
            ``max_connections`` connections already exist.
        """
        try:
            return self._free.popleft()
        except IndexError:
            if len(self._connections) < self.max_connections:
                # We are configuring the connection pool not to retry
                # connections on lower level clients to avoid retrying
                # connections to nodes that are not reachable
                # and to avoid blocking the connection pool.
                # The only error that will have some handling in the lower
                # level clients is ConnectionError which will trigger disconnection
                # of the socket.
                # The retries will be handled on cluster client level
                # where we will have proper handling of the cluster topology
                retry = Retry(
                    backoff=NoBackoff(),
                    retries=0,
                    supported_errors=(ConnectionError,),
                )
                connection_kwargs = self.connection_kwargs.copy()
                connection_kwargs["retry"] = retry
                connection = self.connection_class(**connection_kwargs)
                self._connections.append(connection)
                return connection

            raise MaxConnectionsError()

    async def disconnect_if_needed(self, connection: Connection) -> None:
        """
        Disconnect a connection if it's marked for reconnect.
        This implements lazy disconnection to avoid race conditions.
        The connection will auto-reconnect on next use.
        """
        if connection.should_reconnect():
            await connection.disconnect()

    def release(self, connection: Connection) -> None:
        """
        Release connection back to free queue.
        If the connection is marked for reconnect, it will be disconnected
        lazily when next acquired via disconnect_if_needed().
        """
        self._free.append(connection)

    async def _release_after_use(self, connection: Connection) -> None:
        """Lazily disconnect ``connection`` if flagged, then always return it to the pool."""
        try:
            await self.disconnect_if_needed(connection)
        finally:
            self.release(connection)

    def get_encoder(self) -> Encoder:
        """Return an :class:`Encoder` derived from this node's connection kwargs."""
        kwargs = self.connection_kwargs
        encoder_class = kwargs.get("encoder_class", Encoder)
        return encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def update_active_connections_for_reconnect(self) -> None:
        """
        Mark all in-use (active) connections for reconnect.
        In-use connections are those in _connections but not currently in _free.
        They will be disconnected when released back to the pool.
        """
        free_set = set(self._free)
        for connection in self._connections:
            if connection not in free_set:
                connection.mark_for_reconnect()

    async def disconnect_free_connections(self) -> None:
        """
        Disconnect all free/idle connections in the pool.
        This is useful after topology changes (e.g., failover) to clear
        stale connection state like READONLY mode.
        The connections remain in the pool and will reconnect on next use.
        """
        if self._free:
            # Take a snapshot to avoid issues if _free changes during await
            await asyncio.gather(
                *(connection.disconnect() for connection in tuple(self._free)),
                return_exceptions=True,
            )

    async def parse_response(
        self, connection: Connection, command: str, **kwargs: Any
    ) -> Any:
        """
        Read one reply from ``connection`` and post-process it.

        ``NEVER_DECODE`` in kwargs disables decoding for this read. When
        ``EMPTY_RESPONSE`` is present, its value is returned in place of a
        ResponseError. A registered response callback for ``command`` is
        applied to the reply.
        """
        try:
            if NEVER_DECODE in kwargs:
                response = await connection.read_response(disable_decoding=True)
                kwargs.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in kwargs:
                return kwargs[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in kwargs:
            kwargs.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        kwargs.pop("keys", None)

        # Return response
        if command in self.response_callbacks:
            return self.response_callbacks[command](response, **kwargs)

        return response

    async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
        """
        Send one command on a pooled connection and return its parsed reply.

        The connection is always handed back to the pool, including when the
        lazy disconnect, packing or sending fails, so a failed request cannot
        permanently shrink the pool. A failed send also marks the connection
        for reconnect, so a possibly half-written request is dropped before
        the connection is reused.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)
            packed = connection.pack_command(*args)
            try:
                await connection.send_packed_command(packed, False)
            except BaseException:
                connection.mark_for_reconnect()
                raise
            return await self.parse_response(connection, args[0], **kwargs)
        finally:
            await self._release_after_use(connection)

    async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
        """
        Send ``commands`` as one batch on a pooled connection and read all replies.

        Each command's ``result`` is set to its parsed reply, or to the
        exception raised while reading it. The connection is always returned
        to the pool. A failed send marks it for reconnect first.

        :return: True if at least one command's result is an exception.
        """
        connection = self.acquire_connection()
        try:
            # Handle lazy disconnect for connections marked for reconnect
            await self.disconnect_if_needed(connection)
            packed = connection.pack_commands(cmd.args for cmd in commands)
            try:
                await connection.send_packed_command(packed, False)
            except BaseException:
                connection.mark_for_reconnect()
                raise

            has_errors = False
            for cmd in commands:
                try:
                    cmd.result = await self.parse_response(
                        connection, cmd.args[0], **cmd.kwargs
                    )
                except Exception as e:
                    cmd.result = e
                    has_errors = True
            return has_errors
        finally:
            await self._release_after_use(connection)

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate every idle connection with ``token``.

        Idle connections are taken out of ``_free`` one at a time. Each is
        sent ``AUTH`` with the token's ``oid`` and value, and its reply is
        read. All of them are then put back into ``_free`` in the same order.
        If re-authentication of a connection fails, that connection is marked
        for reconnect. Every connection already taken out is still returned
        to the pool before the error propagates.
        """
        tmp_queue = collections.deque()
        try:
            while self._free:
                conn = self._free.popleft()
                # Track it before awaiting so it is restored even on failure.
                tmp_queue.append(conn)
                try:
                    await conn.retry.call_with_retry(
                        lambda: conn.send_command(
                            "AUTH", token.try_get("oid"), token.get_value()
                        ),
                        lambda error: self._mock(error),
                    )
                    await conn.retry.call_with_retry(
                        lambda: conn.read_response(), lambda error: self._mock(error)
                    )
                except BaseException:
                    # The AUTH exchange may be incomplete; force a fresh
                    # connection before this one is reused.
                    conn.mark_for_reconnect()
                    raise
        finally:
            while tmp_queue:
                self._free.append(tmp_queue.popleft())

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1681
1682
1683class NodesManager:
    # Instances have no per-instance __dict__ (NodesManager has no base class
    # besides object), so every attribute assigned in __init__ must be listed
    # here.
    __slots__ = (
        "_dynamic_startup_nodes",
        "_event_dispatcher",
        "_background_tasks",
        "connection_kwargs",
        "default_node",
        "nodes_cache",
        "_epoch",
        "read_load_balancer",
        "_initialize_lock",
        "require_full_coverage",
        "slots_cache",
        "startup_nodes",
        "address_remap",
    )
1699
1700 def __init__(
1701 self,
1702 startup_nodes: List["ClusterNode"],
1703 require_full_coverage: bool,
1704 connection_kwargs: Dict[str, Any],
1705 dynamic_startup_nodes: bool = True,
1706 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1707 event_dispatcher: Optional[EventDispatcher] = None,
1708 ) -> None:
1709 self.startup_nodes = {node.name: node for node in startup_nodes}
1710 self.require_full_coverage = require_full_coverage
1711 self.connection_kwargs = connection_kwargs
1712 self.address_remap = address_remap
1713
1714 self.default_node: "ClusterNode" = None
1715 self.nodes_cache: Dict[str, "ClusterNode"] = {}
1716 self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1717 self._epoch: int = 0
1718 self.read_load_balancer = LoadBalancer()
1719 self._initialize_lock: asyncio.Lock = asyncio.Lock()
1720
1721 self._background_tasks: Set[asyncio.Task] = set()
1722 self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1723 if event_dispatcher is None:
1724 self._event_dispatcher = EventDispatcher()
1725 else:
1726 self._event_dispatcher = event_dispatcher
1727
1728 def get_node(
1729 self,
1730 host: Optional[str] = None,
1731 port: Optional[int] = None,
1732 node_name: Optional[str] = None,
1733 ) -> Optional["ClusterNode"]:
1734 if host and port:
1735 # the user passed host and port
1736 if host == "localhost":
1737 host = socket.gethostbyname(host)
1738 return self.nodes_cache.get(get_node_name(host=host, port=port))
1739 elif node_name:
1740 return self.nodes_cache.get(node_name)
1741 else:
1742 raise DataError(
1743 "get_node requires one of the following: 1. node name 2. host and port"
1744 )
1745
1746 def set_nodes(
1747 self,
1748 old: Dict[str, "ClusterNode"],
1749 new: Dict[str, "ClusterNode"],
1750 remove_old: bool = False,
1751 ) -> None:
1752 if remove_old:
1753 for name in list(old.keys()):
1754 if name not in new:
1755 # Node is removed from cache before disconnect starts,
1756 # so it won't be found in lookups during disconnect
1757 # Mark active connections for reconnect so they get disconnected after current command completes
1758 # and disconnect free connections immediately
1759 # the node is removed from the cache before the connections changes so it won't be used and should be safe
1760 # not to wait for the disconnects
1761 removed_node = old.pop(name)
1762 removed_node.update_active_connections_for_reconnect()
1763 task = asyncio.create_task(
1764 removed_node.disconnect_free_connections()
1765 )
1766 self._background_tasks.add(task)
1767 task.add_done_callback(self._background_tasks.discard)
1768
1769 for name, node in new.items():
1770 if name in old:
1771 # Preserve the existing node but mark connections for reconnect.
1772 # This method is sync so we can't call disconnect_free_connections()
1773 # which is async. Instead, we mark free connections for reconnect
1774 # and they will be lazily disconnected when acquired via
1775 # disconnect_if_needed() to avoid race conditions.
1776 # TODO: Make this method async in the next major release to allow
1777 # immediate disconnection of free connections.
1778 existing_node = old[name]
1779 existing_node.update_active_connections_for_reconnect()
1780 for conn in existing_node._free:
1781 conn.mark_for_reconnect()
1782 continue
1783 # New node is detected and should be added to the pool
1784 old[name] = node
1785
1786 def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1787 """
1788 Move a failing node to the end of startup_nodes and nodes_cache so it's
1789 tried last during reinitialization and when selecting the default node.
1790 If the node is not in the respective list, nothing is done.
1791 """
1792 # Move in startup_nodes
1793 if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1794 node = self.startup_nodes.pop(node_name)
1795 self.startup_nodes[node_name] = node # Re-insert at end
1796
1797 # Move in nodes_cache - this affects get_nodes_by_server_type ordering
1798 # which is used to select the default_node during initialize()
1799 if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1800 node = self.nodes_cache.pop(node_name)
1801 self.nodes_cache[node_name] = node # Re-insert at end
1802
1803 async def move_slot(self, e: AskError | MovedError):
1804 node_changed = False
1805 redirected_node = self.get_node(host=e.host, port=e.port)
1806 if redirected_node:
1807 # The node already exists
1808 if redirected_node.server_type != PRIMARY:
1809 # Update the node's server type
1810 redirected_node.server_type = PRIMARY
1811 else:
1812 # This is a new node, we will add it to the nodes cache
1813 redirected_node = ClusterNode(
1814 e.host, e.port, PRIMARY, **self.connection_kwargs
1815 )
1816 self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node})
1817 slot_nodes = self.slots_cache[e.slot_id]
1818 if redirected_node not in slot_nodes:
1819 # The new slot owner is a new server, or a server from a different
1820 # shard. We need to remove all current nodes from the slot's list
1821 # (including replications) and add just the new node.
1822 self.slots_cache[e.slot_id] = [redirected_node]
1823 node_changed = True
1824 elif redirected_node is not slot_nodes[0]:
1825 # The MOVED error resulted from a failover, and the new slot owner
1826 # had previously been a replica.
1827 old_primary = slot_nodes[0]
1828 # Update the old primary to be a replica and add it to the end of
1829 # the slot's node list
1830 old_primary.server_type = REPLICA
1831 slot_nodes.append(old_primary)
1832 # Remove the old replica, which is now a primary, from the slot's
1833 # node list
1834 slot_nodes.remove(redirected_node)
1835 # Override the old primary with the new one
1836 slot_nodes[0] = redirected_node
1837 if self.default_node == old_primary:
1838 # Update the default node with the new primary
1839 self.default_node = redirected_node
1840 node_changed = True
1841 # else: circular MOVED to current primary -> no-op
1842 # Dispatch so listeners can run shard-pubsub reconciliation; skipped on
1843 # the no-op branch to avoid needless walks under MOVED storms. A
1844 # listener must not break slots-cache refresh; log and continue so a
1845 # single buggy listener cannot starve the rest.
1846 if node_changed:
1847 try:
1848 await self._event_dispatcher.dispatch_async(
1849 AsyncAfterSlotsCacheRefreshEvent()
1850 )
1851 except Exception as exc:
1852 # Don't shadow the method parameter ``e``: ``except as`` binds
1853 # the listener exception in the function scope and ``del``s
1854 # the name on block exit (PEP 3134), which would also wipe
1855 # out the original AskError/MovedError parameter.
1856 logger.exception(
1857 "listener raised during slots-cache refresh: %s: %s",
1858 type(exc).__name__,
1859 exc,
1860 )
1861
1862 def get_node_from_slot(
1863 self,
1864 slot: int,
1865 read_from_replicas: bool = False,
1866 load_balancing_strategy=None,
1867 ) -> "ClusterNode":
1868 if read_from_replicas is True and load_balancing_strategy is None:
1869 load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN
1870
1871 try:
1872 if len(self.slots_cache[slot]) > 1 and load_balancing_strategy:
1873 # get the server index using the strategy defined in load_balancing_strategy
1874 primary_name = self.slots_cache[slot][0].name
1875 node_idx = self.read_load_balancer.get_server_index(
1876 primary_name, len(self.slots_cache[slot]), load_balancing_strategy
1877 )
1878 return self.slots_cache[slot][node_idx]
1879 return self.slots_cache[slot][0]
1880 except (IndexError, TypeError):
1881 raise SlotNotCoveredError(
1882 f'Slot "{slot}" not covered by the cluster. '
1883 f'"require_full_coverage={self.require_full_coverage}"'
1884 )
1885
1886 def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
1887 return [
1888 node
1889 for node in self.nodes_cache.values()
1890 if node.server_type == server_type
1891 ]
1892
    async def initialize(self) -> None:
        """
        Rebuild ``slots_cache`` and ``nodes_cache`` from ``CLUSTER SLOTS``.

        Startup nodes are queried one after another and their slot maps are
        merged into temporary caches until those cover all
        ``REDIS_CLUSTER_HASH_SLOTS`` slots. Only after the loop ends are the
        real caches replaced, the default node chosen, ``_epoch`` bumped and
        an ``AsyncAfterSlotsCacheRefreshEvent`` dispatched.

        A caller that waited on ``_initialize_lock`` while another call
        bumped ``_epoch`` returns without re-querying.

        :raises RedisClusterException: if no startup node answered
            ``CLUSTER SLOTS`` successfully, if more than five slot-ownership
            disagreements are seen, or if ``require_full_coverage`` is set
            and the merged map does not cover every slot.
        """
        self.read_load_balancer.reset()
        tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
        tmp_slots: Dict[int, List["ClusterNode"]] = {}
        # NOTE: not reset per startup node, so the "more than five" limit
        # below applies across all startup nodes queried in this call.
        disagreements = []
        startup_nodes_reachable = False
        fully_covered = False
        exception = None
        # Snapshot taken before waiting on the lock; compared below.
        epoch = self._epoch

        async with self._initialize_lock:
            if self._epoch != epoch:
                # another initialize call has already reinitialized the
                # nodes since we started waiting for the lock;
                # we don't need to do it again.
                return

            # Convert to tuple to prevent RuntimeError if self.startup_nodes
            # is modified during iteration
            for startup_node in tuple(self.startup_nodes.values()):
                try:
                    # Make sure cluster mode is enabled on this node
                    try:
                        # NOTE: dispatched once per startup node tried, before
                        # CLUSTER SLOTS is sent to it.
                        self._event_dispatcher.dispatch(
                            AfterAsyncClusterInstantiationEvent(
                                self.nodes_cache,
                                self.connection_kwargs.get("credential_provider", None),
                            )
                        )
                        cluster_slots = await startup_node.execute_command(
                            "CLUSTER SLOTS"
                        )
                    except ResponseError:
                        raise RedisClusterException(
                            "Cluster mode is not enabled on this node"
                        )
                    startup_nodes_reachable = True
                except Exception as e:
                    # Try the next startup node.
                    # The exception is saved and raised only if we have no more nodes.
                    exception = e
                    continue

                # CLUSTER SLOTS command results in the following output:
                # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]]
                # where each node contains the following list: [IP, port, node_id]
                # Therefore, cluster_slots[0][2][0] will be the IP address of the
                # primary node of the first slot section.
                # If there's only one server in the cluster, its ``host`` is ''
                # Fix it to the host in startup_nodes
                if (
                    len(cluster_slots) == 1
                    and not cluster_slots[0][2][0]
                    and len(self.startup_nodes) == 1
                ):
                    cluster_slots[0][2][0] = startup_node.host

                for slot in cluster_slots:
                    # Decode every node entry (primary and replicas) in place.
                    for i in range(2, len(slot)):
                        slot[i] = [str_if_bytes(val) for val in slot[i]]
                    primary_node = slot[2]
                    host = primary_node[0]
                    if host == "":
                        host = startup_node.host
                    port = int(primary_node[1])
                    host, port = self.remap_host_port(host, port)

                    # One list per CLUSTER SLOTS range: [primary, *replicas].
                    nodes_for_slot = []

                    target_node = tmp_nodes_cache.get(get_node_name(host, port))
                    if not target_node:
                        target_node = ClusterNode(
                            host, port, PRIMARY, **self.connection_kwargs
                        )
                    # add this node to the nodes cache
                    tmp_nodes_cache[target_node.name] = target_node
                    nodes_for_slot.append(target_node)

                    replica_nodes = slot[3:]
                    for replica_node in replica_nodes:
                        host = replica_node[0]
                        # NOTE(review): unlike the primary above, the replica
                        # port is not passed through int() and an empty host is
                        # not replaced with the startup node's host -- confirm
                        # CLUSTER SLOTS always yields usable values here.
                        port = replica_node[1]
                        host, port = self.remap_host_port(host, port)

                        target_replica_node = tmp_nodes_cache.get(
                            get_node_name(host, port)
                        )
                        if not target_replica_node:
                            target_replica_node = ClusterNode(
                                host, port, REPLICA, **self.connection_kwargs
                            )
                        # add this node to the nodes cache
                        tmp_nodes_cache[target_replica_node.name] = target_replica_node
                        nodes_for_slot.append(target_replica_node)

                    # Every slot of the range shares the same ``nodes_for_slot``
                    # list object.
                    for i in range(int(slot[0]), int(slot[1]) + 1):
                        if i not in tmp_slots:
                            tmp_slots[i] = nodes_for_slot
                        else:
                            # Validate that 2 nodes want to use the same slot cache
                            # setup
                            tmp_slot = tmp_slots[i][0]
                            if tmp_slot.name != target_node.name:
                                disagreements.append(
                                    f"{tmp_slot.name} vs {target_node.name} on slot: {i}"
                                )

                                if len(disagreements) > 5:
                                    raise RedisClusterException(
                                        f"startup_nodes could not agree on a valid "
                                        f"slots cache: {', '.join(disagreements)}"
                                    )

                # Validate if all slots are covered or if we should try next startup node
                fully_covered = True
                for i in range(REDIS_CLUSTER_HASH_SLOTS):
                    if i not in tmp_slots:
                        fully_covered = False
                        break
                if fully_covered:
                    break

            if not startup_nodes_reachable:
                raise RedisClusterException(
                    f"Redis Cluster cannot be connected. Please provide at least "
                    f"one reachable node: {str(exception)}"
                ) from exception

            # Check if the slots are not fully covered
            if not fully_covered and self.require_full_coverage:
                # Despite the requirement that the slots be covered, there
                # isn't a full coverage
                raise RedisClusterException(
                    f"All slots are not covered after query all startup_nodes. "
                    f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} "
                    f"covered..."
                )

            # Set the tmp variables to the real variables
            self.slots_cache = tmp_slots
            self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)

            if self._dynamic_startup_nodes:
                # Populate the startup nodes with all discovered nodes
                self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)

            # Set the default node
            # NOTE(review): raises IndexError if no PRIMARY ends up in
            # nodes_cache (presumably possible when require_full_coverage is
            # False and CLUSTER SLOTS returned no ranges) -- confirm.
            self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
            self._epoch += 1
            # Dispatch so listeners (e.g. ClusterPubSub) can reconcile per-node
            # state after slot ownership may have changed. A listener must not
            # break slots-cache refresh; log and continue so a single buggy
            # listener cannot starve the rest.
            try:
                await self._event_dispatcher.dispatch_async(
                    AsyncAfterSlotsCacheRefreshEvent()
                )
            except Exception as e:
                logger.exception(
                    "listener raised during slots-cache refresh: %s: %s",
                    type(e).__name__,
                    e,
                )
2056
2057 async def aclose(self, attr: str = "nodes_cache") -> None:
2058 self.default_node = None
2059 await asyncio.gather(
2060 *(
2061 asyncio.create_task(node.disconnect())
2062 for node in getattr(self, attr).values()
2063 )
2064 )
2065
2066 def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
2067 """
2068 Remap the host and port returned from the cluster to a different
2069 internal value. Useful if the client is not connecting directly
2070 to the cluster.
2071 """
2072 if self.address_remap:
2073 return self.address_remap((host, port))
2074 return host, port
2075
2076
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
    """
    Create a new ClusterPipeline object.

    Usage::

        result = await (
            rc.pipeline()
            .set("A", 1)
            .get("A")
            .hset("K", "F", "V")
            .hgetall("K")
            .mset_nonatomic({"A": 2, "B": 3})
            .get("A")
            .get("B")
            .delete("A", "B", "K")
            .execute()
        )
        # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]

    Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
    are split across multiple nodes, you'll get multiple results for them in the array.

    Retryable errors:
        - :class:`~.ClusterDownError`
        - :class:`~.ConnectionError`
        - :class:`~.TimeoutError`

    Redirection errors:
        - :class:`~.TryAgainError`
        - :class:`~.MovedError`
        - :class:`~.AskError`

    :param client:
        | Existing :class:`~.RedisCluster` client
    :param transaction:
        | When truthy, commands run as a MULTI/EXEC transaction
          (:class:`TransactionStrategy`); otherwise queued commands are
          grouped per node (:class:`PipelineStrategy`).
    """

    __slots__ = (
        "cluster_client",
        "_transaction",
        "_execution_strategy",
    )

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    def __init__(
        self, client: RedisCluster, transaction: Optional[bool] = None
    ) -> None:
        self.cluster_client = client
        self._transaction = transaction
        # Every queueing/execution operation below is delegated to this object.
        self._execution_strategy: ExecutionStrategy = (
            PipelineStrategy(self)
            if not self._transaction
            else TransactionStrategy(self)
        )

    @property
    def nodes_manager(self) -> "NodesManager":
        """Get the nodes manager from the cluster client."""
        return self.cluster_client.nodes_manager

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom response callback on the cluster client."""
        self.cluster_client.set_response_callback(command, callback)

    async def initialize(self) -> "ClusterPipeline":
        """Let the execution strategy prepare itself, then return the pipeline."""
        await self._execution_strategy.initialize()
        return self

    async def __aenter__(self) -> "ClusterPipeline":
        return await self.initialize()

    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
        # Drop queued commands and strategy state when leaving the block.
        await self.reset()

    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
        # ``await rc.pipeline()`` initializes the pipeline and evaluates to it.
        return self.initialize().__await__()

    def __bool__(self) -> bool:
        "Pipeline instances should always evaluate to True on Python 3+"
        return True

    def __len__(self) -> int:
        """Return the execution strategy's length (its queued command count)."""
        return len(self._execution_strategy)

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Append a raw command to the pipeline.

        :param args:
            | Raw command args
        :param kwargs:

            - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
              or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
            - Rest of the kwargs are passed to the Redis connection
        """
        return self._execution_strategy.execute_command(*args, **kwargs)

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Execute the pipeline.

        It will retry the commands as specified by retries specified in :attr:`retry`
        & then raise an exception.

        The pipeline is reset afterwards, whether execution succeeded or not.

        :param raise_on_error:
            | Raise the first error if there are any errors
        :param allow_redirections:
            | Whether to retry each failed command individually in case of redirection
              errors

        :raises RedisClusterException: if target_nodes is not provided & the command
            can't be mapped to a slot
        """
        try:
            return await self._execution_strategy.execute(
                raise_on_error, allow_redirections
            )
        finally:
            await self.reset()

    def _split_command_across_slots(
        self, command: str, *keys: KeyT
    ) -> "ClusterPipeline":
        """Queue ``command`` once per hash slot, each with that slot's keys."""
        for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
            self.execute_command(command, *slot_keys)

        return self

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        await self._execution_strategy.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        self._execution_strategy.multi()

    async def discard(self):
        """
        Discard the transaction; delegated to the execution strategy.

        The non-transactional strategy raises ``RedisClusterException``.
        Returns whatever the strategy returns.
        """
        return await self._execution_strategy.discard()

    async def watch(self, *names):
        """Watches the values at keys ``names``; returns the strategy's result."""
        return await self._execution_strategy.watch(*names)

    async def unwatch(self):
        """Unwatches all previously specified keys; returns the strategy's result."""
        return await self._execution_strategy.unwatch()

    async def unlink(self, *names):
        """
        Queue an UNLINK for ``names``.

        Returns the strategy's result (the non-transactional strategy returns
        the pipeline, so calls can be chained).
        """
        return await self._execution_strategy.unlink(*names)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue one MSET per hash slot for ``mapping`` (not atomic across slots)."""
        return self._execution_strategy.mset_nonatomic(mapping)
2244
2245
# Replace each blocked command's method on ClusterPipeline with the stub
# produced by ``block_pipeline_command``. Names like "CLUSTER X" are mapped to
# the method name "cluster_x". ``mset_nonatomic`` is skipped because
# ClusterPipeline provides its own pipelined implementation.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
2252
2253
class PipelineCommand:
    """One queued pipeline command together with its eventual outcome."""

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        # Raw command arguments and per-command options.
        self.args = args
        self.kwargs = kwargs
        # 0-based index of the command in the pipeline queue.
        self.position = position
        # Reply value or the Exception the command produced; None until set.
        self.result: Union[Any, Exception] = None
        # Routing/aggregation policies, assigned when the command is routed.
        self.command_policies: Optional[CommandPolicies] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)
2264
2265
class ExecutionStrategy(ABC):
    """
    Interface for the back-ends a ``ClusterPipeline`` delegates to.

    Each method mirrors the ``ClusterPipeline`` method of the same name.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """Prepare the strategy; see ``ClusterPipeline.initialize()``."""

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Queue a raw command; see ``ClusterPipeline.execute_command()``."""

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands, retrying per :attr:`retry` before raising.

        See ``ClusterPipeline.execute()``.
        """

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """Queue per-slot MSET commands; see ``ClusterPipeline.mset_nonatomic()``."""

    @abstractmethod
    async def reset(self):
        """Return the strategy to its empty state; see ``ClusterPipeline.reset()``."""

    @abstractmethod
    def multi(self):
        """Begin the transactional block; see ``ClusterPipeline.multi()``."""

    @abstractmethod
    async def watch(self, *names):
        """Watch the given keys; see ``ClusterPipeline.watch()``."""

    @abstractmethod
    async def unwatch(self):
        """Stop watching all keys; see ``ClusterPipeline.unwatch()``."""

    @abstractmethod
    async def discard(self):
        """Discard the transaction; see ``ClusterPipeline.discard()``."""

    @abstractmethod
    async def unlink(self, *names):
        """Unlink the keys given by ``names``; see ``ClusterPipeline.unlink()``."""

    @abstractmethod
    def __len__(self) -> int:
        """Number of queued commands."""
2364
2365
class AbstractStrategy(ExecutionStrategy):
    """
    Shared base for execution strategies: owns the queue of PipelineCommands.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        """Initialize the cluster client if still pending and clear the queue."""
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """Append a command (positioned at the end of the queue); return the pipe."""
        self._command_queue.append(
            PipelineCommand(len(self._command_queue), *args, **kwargs)
        )
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Provides extra context to the exception prior to it being handled.

        Rewrites ``exception.args[0]`` in place so the message names the
        failing command.

        :param exception: the exception to annotate.
        :param number: 1-based position of the command in the pipeline.
        :param command: the command's argument sequence.
        """
        cmd = " ".join(map(safe_str, command))
        # Exceptions created without arguments have an empty ``args``; use an
        # empty detail rather than failing with IndexError while annotating.
        detail = exception.args[0] if exception.args else ""
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
            f"caused error: {detail}"
        )
        exception.args = (msg,) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        pass

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        pass

    @abstractmethod
    async def reset(self):
        pass

    @abstractmethod
    def multi(self):
        pass

    @abstractmethod
    async def watch(self, *names):
        pass

    @abstractmethod
    async def unwatch(self):
        pass

    @abstractmethod
    async def discard(self):
        pass

    @abstractmethod
    async def unlink(self, *names):
        pass

    def __len__(self) -> int:
        return len(self._command_queue)
2434
2435
class PipelineStrategy(AbstractStrategy):
    """
    Non-transactional cluster pipeline.

    Each queued command is routed to exactly one node. Commands are grouped
    per node and the groups are sent concurrently with
    ``ClusterNode.execute_pipeline``. Redirected commands can be retried one
    by one.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one ``MSET`` per hash slot for the pairs in ``mapping``.

        The writes are not atomic across slots.

        :return: the owning pipeline, for chaining.
        """
        encoder = self._pipe.cluster_client.encoder

        slots_pairs: Dict[int, List[Any]] = {}
        for key, value in mapping.items():
            slot = key_slot(encoder.encode(key))
            slots_pairs.setdefault(slot, []).extend((key, value))

        for pairs in slots_pairs.values():
            self.execute_command("MSET", *pairs)

        return self._pipe

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queue, retrying up to ``retry.get_retries()`` times on
        ``RedisCluster.ERRORS_ALLOW_RETRY`` errors (closing the client and
        sleeping 0.25s between attempts). The queue is always reset.
        """
        if not self._command_queue:
            return []

        try:
            retry_attempts = self._pipe.cluster_client.retry.get_retries()
            while True:
                try:
                    if self._pipe.cluster_client._initialize:
                        await self._pipe.cluster_client.initialize()
                    return await self._execute(
                        self._pipe.cluster_client,
                        self._command_queue,
                        raise_on_error=raise_on_error,
                        allow_redirections=allow_redirections,
                    )

                except RedisCluster.ERRORS_ALLOW_RETRY:
                    if retry_attempts <= 0:
                        # All other errors should be raised.
                        raise
                    # Try again with the new cluster setup.
                    retry_attempts -= 1
                    await self._pipe.cluster_client.aclose()
                    await asyncio.sleep(0.25)
        finally:
            await self.reset()

    async def _execute(
        self,
        client: "RedisCluster",
        stack: List["PipelineCommand"],
        raise_on_error: bool = True,
        allow_redirections: bool = True,
    ) -> List[Any]:
        """
        Send one round of the pipeline and return every command's result.

        Only commands without a result yet (``None``) or whose result is an
        Exception are sent; results kept from an earlier attempt are reused.

        :raises RedisClusterException: if a command resolves to no node or to
            more than one node.
        """
        # ``is None`` rather than falsiness: a command that already returned a
        # falsy reply (0, False, b"", []) in a previous attempt must not be
        # re-sent -- repeating e.g. DECR would modify data twice.
        todo = [
            cmd
            for cmd in stack
            if cmd.result is None or isinstance(cmd.result, Exception)
        ]

        nodes = {}
        for cmd in todo:
            passed_targets = cmd.kwargs.pop("target_nodes", None)
            command_policies = await client._policy_resolver.resolve(
                cmd.args[0].lower()
            )

            if passed_targets and not client._is_node_flag(passed_targets):
                target_nodes = client._parse_target_nodes(passed_targets)

                if not command_policies:
                    command_policies = CommandPolicies()
            else:
                if not command_policies:
                    command_flag = client.command_flags.get(cmd.args[0])
                    if not command_flag:
                        # Fallback to default policy
                        if not client.get_default_node():
                            slot = None
                        else:
                            slot = await client._determine_slot(*cmd.args)
                        if slot is None:
                            command_policies = CommandPolicies()
                        else:
                            command_policies = CommandPolicies(
                                request_policy=RequestPolicy.DEFAULT_KEYED,
                                response_policy=ResponsePolicy.DEFAULT_KEYED,
                            )
                    else:
                        if command_flag in client._command_flags_mapping:
                            command_policies = CommandPolicies(
                                request_policy=client._command_flags_mapping[
                                    command_flag
                                ]
                            )
                        else:
                            command_policies = CommandPolicies()

                target_nodes = await client._determine_nodes(
                    *cmd.args,
                    request_policy=command_policies.request_policy,
                    node_flag=passed_targets,
                )
                if not target_nodes:
                    raise RedisClusterException(
                        f"No targets were found to execute {cmd.args} command on"
                    )
            cmd.command_policies = command_policies
            if len(target_nodes) > 1:
                raise RedisClusterException(f"Too many targets for command {cmd.args}")
            node = target_nodes[0]
            if node.name not in nodes:
                nodes[node.name] = (node, [])
            nodes[node.name][1].append(cmd)

        # Start timing for observability
        start_time = time.monotonic()

        errors = await asyncio.gather(
            *(
                asyncio.create_task(node[0].execute_pipeline(node[1]))
                for node in nodes.values()
            )
        )

        # All node groups ran under the same gather, so measure once instead
        # of re-reading the clock after each awaited metric call below.
        duration_seconds = time.monotonic() - start_time

        # Record operation duration for each node
        for node, commands in nodes.values():
            # The first error among this node's commands, if any
            node_error = next(
                (cmd.result for cmd in commands if isinstance(cmd.result, Exception)),
                None,
            )

            db = node.connection_kwargs.get("db", 0)
            await record_operation_duration(
                command_name="PIPELINE",
                duration_seconds=duration_seconds,
                server_address=node.host,
                server_port=node.port,
                db_namespace=str(db) if db is not None else None,
                error=node_error,
            )

        if any(errors):
            if allow_redirections:
                # send each errored command individually
                for cmd in todo:
                    if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
                        try:
                            cmd.result = client._policies_callback_mapping[
                                cmd.command_policies.response_policy
                            ](await client.execute_command(*cmd.args, **cmd.kwargs))
                        except Exception as e:
                            cmd.result = e

            if raise_on_error:
                for cmd in todo:
                    result = cmd.result
                    if isinstance(result, Exception):
                        command = " ".join(map(safe_str, cmd.args))
                        # Same message shape as ``_annotate_exception``: use the
                        # original message, not the repr of the args tuple.
                        detail = result.args[0] if result.args else ""
                        msg = (
                            f"Command # {cmd.position + 1} "
                            f"({truncate_text(command)}) "
                            f"of pipeline caused error: {detail}"
                        )
                        result.args = (msg,) + result.args[1:]
                        raise result

            default_cluster_node = client.get_default_node()

            # Check whether the default node was used. In some cases,
            # 'client.get_default_node()' may return None. The check below
            # prevents a potential AttributeError.
            if default_cluster_node is not None:
                default_node = nodes.get(default_cluster_node.name)
                if default_node is not None:
                    # This pipeline execution used the default node, check if we need
                    # to replace it.
                    # Note: when the error is raised we'll reset the default node in the
                    # caller function.
                    for cmd in default_node[1]:
                        # Check if it has a command that failed with a relevant
                        # exception
                        if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                            client.replace_default_node()
                            break

        return [cmd.result for cmd in stack]

    async def reset(self):
        """
        Reset back to empty pipeline.
        """
        self._command_queue = []

    def multi(self):
        raise RedisClusterException(
            "method multi() is not supported outside of transactional context"
        )

    async def watch(self, *names):
        raise RedisClusterException(
            "method watch() is not supported outside of transactional context"
        )

    async def unwatch(self):
        raise RedisClusterException(
            "method unwatch() is not supported outside of transactional context"
        )

    async def discard(self):
        raise RedisClusterException(
            "method discard() is not supported outside of transactional context"
        )

    async def unlink(self, *names):
        """Queue UNLINK for a single key; multiple keys are rejected."""
        if len(names) != 1:
            raise RedisClusterException(
                "unlinking multiple keys is not implemented in pipeline command"
            )

        return self.execute_command("UNLINK", names[0])
2661
2662
2663class TransactionStrategy(AbstractStrategy):
2664 NO_SLOTS_COMMANDS = {"UNWATCH"}
2665 IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
2666 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
2667 SLOT_REDIRECT_ERRORS = (AskError, MovedError)
2668 CONNECTION_ERRORS = (
2669 ConnectionError,
2670 OSError,
2671 ClusterDownError,
2672 SlotNotCoveredError,
2673 )
2674
2675 def __init__(self, pipe: ClusterPipeline) -> None:
2676 super().__init__(pipe)
2677 self._explicit_transaction = False
2678 self._watching = False
2679 self._pipeline_slots: Set[int] = set()
2680 self._transaction_node: Optional[ClusterNode] = None
2681 self._transaction_connection: Optional[Connection] = None
2682 self._executing = False
2683 self._retry = copy(self._pipe.cluster_client.retry)
2684 self._retry.update_supported_errors(
2685 RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2686 )
2687
2688 def _get_client_and_connection_for_transaction(
2689 self,
2690 ) -> Tuple[ClusterNode, Connection]:
2691 """
2692 Find a connection for a pipeline transaction.
2693
2694 For running an atomic transaction, watch keys ensure that contents have not been
2695 altered as long as the watch commands for those keys were sent over the same
2696 connection. So once we start watching a key, we fetch a connection to the
2697 node that owns that slot and reuse it.
2698 """
2699 if not self._pipeline_slots:
2700 raise RedisClusterException(
2701 "At least a command with a key is needed to identify a node"
2702 )
2703
2704 node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2705 list(self._pipeline_slots)[0], False
2706 )
2707 self._transaction_node = node
2708
2709 if not self._transaction_connection:
2710 connection: Connection = self._transaction_node.acquire_connection()
2711 self._transaction_connection = connection
2712
2713 return self._transaction_node, self._transaction_connection
2714
2715 def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2716 # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2717 response = None
2718 error = None
2719
2720 def runner():
2721 nonlocal response
2722 nonlocal error
2723 try:
2724 response = asyncio.run(self._execute_command(*args, **kwargs))
2725 except Exception as e:
2726 error = e
2727
2728 thread = threading.Thread(target=runner)
2729 thread.start()
2730 thread.join()
2731
2732 if error:
2733 raise error
2734
2735 return response
2736
2737 async def _execute_command(
2738 self, *args: Union[KeyT, EncodableT], **kwargs: Any
2739 ) -> Any:
2740 if self._pipe.cluster_client._initialize:
2741 await self._pipe.cluster_client.initialize()
2742
2743 slot_number: Optional[int] = None
2744 if args[0] not in self.NO_SLOTS_COMMANDS:
2745 slot_number = await self._pipe.cluster_client._determine_slot(*args)
2746
2747 if (
2748 self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2749 ) and not self._explicit_transaction:
2750 if args[0] == "WATCH":
2751 self._validate_watch()
2752
2753 if slot_number is not None:
2754 if self._pipeline_slots and slot_number not in self._pipeline_slots:
2755 raise CrossSlotTransactionError(
2756 "Cannot watch or send commands on different slots"
2757 )
2758
2759 self._pipeline_slots.add(slot_number)
2760 elif args[0] not in self.NO_SLOTS_COMMANDS:
2761 raise RedisClusterException(
2762 f"Cannot identify slot number for command: {args[0]},"
2763 "it cannot be triggered in a transaction"
2764 )
2765
2766 return self._immediate_execute_command(*args, **kwargs)
2767 else:
2768 if slot_number is not None:
2769 self._pipeline_slots.add(slot_number)
2770
2771 return super().execute_command(*args, **kwargs)
2772
2773 def _validate_watch(self):
2774 if self._explicit_transaction:
2775 raise RedisError("Cannot issue a WATCH after a MULTI")
2776
2777 self._watching = True
2778
2779 async def _immediate_execute_command(self, *args, **options):
2780 return await self._retry.call_with_retry(
2781 lambda: self._get_connection_and_send_command(*args, **options),
2782 self._reinitialize_on_error,
2783 with_failure_count=True,
2784 )
2785
2786 async def _get_connection_and_send_command(self, *args, **options):
2787 redis_node, connection = self._get_client_and_connection_for_transaction()
2788 # Only disconnect if not watching - disconnecting would lose WATCH state
2789 if not self._watching:
2790 await redis_node.disconnect_if_needed(connection)
2791
2792 # Start timing for observability
2793 start_time = time.monotonic()
2794
2795 try:
2796 response = await self._send_command_parse_response(
2797 connection, redis_node, args[0], *args, **options
2798 )
2799
2800 await record_operation_duration(
2801 command_name=args[0],
2802 duration_seconds=time.monotonic() - start_time,
2803 server_address=connection.host,
2804 server_port=connection.port,
2805 db_namespace=str(connection.db),
2806 )
2807
2808 return response
2809 except Exception as e:
2810 e.connection = connection
2811 await record_operation_duration(
2812 command_name=args[0],
2813 duration_seconds=time.monotonic() - start_time,
2814 server_address=connection.host,
2815 server_port=connection.port,
2816 db_namespace=str(connection.db),
2817 error=e,
2818 )
2819 raise
2820
2821 async def _send_command_parse_response(
2822 self,
2823 connection: Connection,
2824 redis_node: ClusterNode,
2825 command_name,
2826 *args,
2827 **options,
2828 ):
2829 """
2830 Send a command and parse the response
2831 """
2832
2833 await connection.send_command(*args)
2834 output = await redis_node.parse_response(connection, command_name, **options)
2835
2836 if command_name in self.UNWATCH_COMMANDS:
2837 self._watching = False
2838 return output
2839
2840 async def _reinitialize_on_error(self, error, failure_count):
2841 if hasattr(error, "connection"):
2842 await record_error_count(
2843 server_address=error.connection.host,
2844 server_port=error.connection.port,
2845 network_peer_address=error.connection.host,
2846 network_peer_port=error.connection.port,
2847 error_type=error,
2848 retry_attempts=failure_count,
2849 is_internal=True,
2850 )
2851
2852 if self._watching:
2853 if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2854 raise WatchError("Slot rebalancing occurred while watching keys")
2855
2856 if (
2857 type(error) in self.SLOT_REDIRECT_ERRORS
2858 or type(error) in self.CONNECTION_ERRORS
2859 ):
2860 if self._transaction_connection and self._transaction_node:
2861 # Disconnect and release back to pool
2862 await self._transaction_connection.disconnect()
2863 self._transaction_node.release(self._transaction_connection)
2864 self._transaction_connection = None
2865
2866 self._pipe.cluster_client.reinitialize_counter += 1
2867 if (
2868 self._pipe.cluster_client.reinitialize_steps
2869 and self._pipe.cluster_client.reinitialize_counter
2870 % self._pipe.cluster_client.reinitialize_steps
2871 == 0
2872 ):
2873 await self._pipe.cluster_client.nodes_manager.initialize()
2874 self.reinitialize_counter = 0
2875 else:
2876 if isinstance(error, AskError):
2877 await self._pipe.cluster_client.nodes_manager.move_slot(error)
2878
2879 self._executing = False
2880
2881 async def _raise_first_error(self, responses, stack, start_time):
2882 """
2883 Raise the first exception on the stack
2884 """
2885 for r, cmd in zip(responses, stack):
2886 if isinstance(r, Exception):
2887 self._annotate_exception(r, cmd.position + 1, cmd.args)
2888
2889 await record_operation_duration(
2890 command_name="TRANSACTION",
2891 duration_seconds=time.monotonic() - start_time,
2892 server_address=self._transaction_connection.host,
2893 server_port=self._transaction_connection.port,
2894 db_namespace=str(self._transaction_connection.db),
2895 error=r,
2896 )
2897
2898 raise r
2899
2900 def mset_nonatomic(
2901 self, mapping: Mapping[AnyKeyT, EncodableT]
2902 ) -> "ClusterPipeline":
2903 raise NotImplementedError("Method is not supported in transactional context.")
2904
2905 async def execute(
2906 self, raise_on_error: bool = True, allow_redirections: bool = True
2907 ) -> List[Any]:
2908 stack = self._command_queue
2909 if not stack and (not self._watching or not self._pipeline_slots):
2910 return []
2911
2912 return await self._execute_transaction_with_retries(stack, raise_on_error)
2913
2914 async def _execute_transaction_with_retries(
2915 self, stack: List["PipelineCommand"], raise_on_error: bool
2916 ):
2917 return await self._retry.call_with_retry(
2918 lambda: self._execute_transaction(stack, raise_on_error),
2919 lambda error, failure_count: self._reinitialize_on_error(
2920 error, failure_count
2921 ),
2922 with_failure_count=True,
2923 )
2924
    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send the queued commands wrapped in MULTI/EXEC on one connection and
        return their post-processed results.

        Every per-command reply (and the MULTI reply) is read off the socket
        before EXEC is parsed, even when some of them are errors.

        :param stack: queued commands. The reply-reading loops below iterate
            ``self._command_queue``, so this is expected to be that same list
            (``execute`` passes it).
        :param raise_on_error: raise the first per-command error found in the
            EXEC reply instead of returning it in place.
        :return: one result per queued command, with the cluster client's
            response callbacks applied to non-error results
        :raises CrossSlotTransactionError: if the queued keys span more than
            one hash slot
        :raises WatchError: if EXEC returned ``None`` (a watched key changed)
        :raises InvalidPipelineStack: if the EXEC reply length does not match
            the number of queued commands
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        # Marks an in-flight EXEC; _reinitialize_on_error checks this flag to
        # turn slot redirects into WatchError while keys are watched.
        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()
        # Only disconnect if not watching - disconnecting would lose WATCH state
        if not self._watching:
            await redis_node.disconnect_if_needed(connection)

        stack = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        # Commands carrying EMPTY_RESPONSE are not sent; their stored value is
        # spliced back into the EXEC reply further down.
        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)

        # Start timing for observability
        start_time = time.monotonic()

        await connection.send_packed_command(packed_commands)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, "MULTI")
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, "MULTI")
            # Attached so _reinitialize_on_error can record the failing peer.
            cluster_error.connection = connection
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    # Only checked for errors; the reply value is discarded.
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    cluster_error.connection = connection
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # Surface the underlying per-command error (e.g. a redirect) so
            # the retry policy can react to it.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        # NOTE(review): ``errors`` can hold both bare exceptions (appended
        # above) and ``(index, value)`` tuples for EMPTY_RESPONSE commands;
        # this unpacking only works for the tuples. It relies on EXEC having
        # raised ExecAbortError (handled above) whenever a queued command
        # failed — confirm.
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        # NOTE(review): at this point a non-empty ``errors`` only holds
        # EMPTY_RESPONSE tuples, yet it still enables raising when
        # raise_on_error is False — confirm this is intended.
        if raise_on_error or len(errors) > 0:
            await self._raise_first_error(
                response,
                self._command_queue,
                start_time,
            )

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in self._pipe.cluster_client.response_callbacks:
                    r = self._pipe.cluster_client.response_callbacks[command_name](
                        r, **cmd.kwargs
                    )
            data.append(r)

        await record_operation_duration(
            command_name="TRANSACTION",
            duration_seconds=time.monotonic() - start_time,
            server_address=connection.host,
            server_port=connection.port,
            db_namespace=str(connection.db),
        )

        return data
3043
    async def reset(self):
        """
        Clear queued commands and all transaction state.

        If a transaction connection is held:
        - UNWATCH is sent first when keys are being watched.
        - The connection is then released back to its node.
        - On a connection error it is disconnected (which also drops any
          previous WATCHes) and released instead.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        # NOTE(review): an error other than CONNECTION_ERRORS raised by
        # UNWATCH propagates and skips the attribute cleanup at the end.
        if self._transaction_connection:
            try:
                if self._watching:
                    # call this manually since our unwatch or
                    # immediate_execute_command methods can call reset()
                    await self._transaction_connection.send_command("UNWATCH")
                    await self._transaction_connection.read_response()
                # we can safely return the connection to the pool here since we're
                # sure we're no longer WATCHing anything
                self._transaction_node.release(self._transaction_connection)
                self._transaction_connection = None
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes
                if self._transaction_connection and self._transaction_node:
                    await self._transaction_connection.disconnect()
                    self._transaction_node.release(self._transaction_connection)
                    self._transaction_connection = None

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False
3073
3074 def multi(self):
3075 if self._explicit_transaction:
3076 raise RedisError("Cannot issue nested calls to MULTI")
3077 if self._command_queue:
3078 raise RedisError(
3079 "Commands without an initial WATCH have already been issued"
3080 )
3081 self._explicit_transaction = True
3082
3083 async def watch(self, *names):
3084 if self._explicit_transaction:
3085 raise RedisError("Cannot issue a WATCH after a MULTI")
3086
3087 return await self.execute_command("WATCH", *names)
3088
3089 async def unwatch(self):
3090 if self._watching:
3091 return await self.execute_command("UNWATCH")
3092
3093 return True
3094
3095 async def discard(self):
3096 await self.reset()
3097
3098 async def unlink(self, *names):
3099 return self.execute_command("UNLINK", *names)
3100
3101
class _ClusterNodePoolAdapter(ConnectionPoolInterface):
    """Thin adapter exposing the :class:`ConnectionPoolInterface` that
    :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own
    connection pool.

    Connections are acquired from the node via
    :meth:`ClusterNode.acquire_connection` and returned via
    :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already
    disconnects the connection *before* calling :meth:`release`, so the
    connection is returned to the node's free-queue in a disconnected
    state — guaranteeing that a subscribed socket is never silently
    reused for regular commands.

    Methods that do not apply to this adapter (the underlying node's
    lifecycle is managed by the cluster, not by individual PubSub
    instances) are implemented as no-ops so the adapter remains a valid
    :class:`ConnectionPoolInterface`.
    """

    def __init__(self, node: "ClusterNode") -> None:
        self._node = node
        self.connection_kwargs = node.connection_kwargs

    # -- methods used by PubSub ------------------------------------------------

    def get_encoder(self) -> Encoder:
        return self._node.get_encoder()

    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> AbstractConnection:
        """
        Acquire a connection from the node and connect it.

        ``command_name``, ``keys`` and ``options`` are accepted for interface
        compatibility and ignored. If connecting fails, the connection is
        disconnected and handed back to the node before the error propagates.
        """
        connection = self._node.acquire_connection()
        try:
            await connection.connect()
        except BaseException:
            # connect() may fail mid-handshake (e.g. after the TCP socket
            # is established but before AUTH/HELLO completes) leaving the
            # connection in a partially-connected state. Disconnect before
            # returning it to the node's free queue so it is not reused.
            # release() sits in ``finally`` so the connection goes back to
            # the node even if disconnect() itself raises; otherwise the
            # acquired connection would never be returned.
            try:
                await connection.disconnect()
            finally:
                self._node.release(connection)
            raise
        return connection

    async def release(self, connection: AbstractConnection) -> None:
        # PubSub.aclose() disconnects the connection before calling
        # release(), so it is safe to put it back in the node's free
        # queue – it will reconnect lazily on next use.
        self._node.release(connection)

    # -- no-op stubs for the rest of ConnectionPoolInterface -------------------
    # The node's connections are shared with regular cluster traffic and its
    # lifecycle is managed by RedisCluster / NodesManager, so the adapter must
    # not reset, disconnect, retry-configure or re-auth them on behalf of a
    # single PubSub instance.

    def get_protocol(self):
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        pass

    async def disconnect(self, inuse_connections: bool = True) -> None:
        pass

    async def aclose(self) -> None:
        pass

    def set_retry(self, retry: "Retry") -> None:
        pass

    async def re_auth_callback(self, token: TokenInterface) -> None:
        pass

    def get_connection_count(self) -> List[Tuple[int, dict]]:
        return []
3178
3179
def _unregister_slots_cache_listener(
    dispatcher_ref: "weakref.ref[EventDispatcher]",
    listener: AsyncEventListenerInterface,
    event_type: Type[object],
) -> None:
    """
    Finalizer callback that removes ``listener`` from the dispatcher.

    It lives at module level and receives only a weak reference to the
    dispatcher, so registering it with ``weakref.finalize`` does not keep
    the owning ClusterPubSub alive. Nothing happens if the dispatcher has
    already been collected.
    """
    dispatcher = dispatcher_ref()
    if dispatcher is None:
        return
    dispatcher.unregister_listeners({event_type: [listener]})
3191
3192
class ClusterPubSubSlotsCacheListener(AsyncEventListenerInterface):
    """
    Async listener that forwards AsyncAfterSlotsCacheRefreshEvent to a
    ClusterPubSub.

    Only a weak reference to the pubsub is kept, so the listener never
    extends its lifetime. The dispatcher's strong reference to this listener
    is removed by a ``weakref.finalize`` attached to the owning ClusterPubSub
    in ``ClusterPubSub.__init__``.
    """

    def __init__(self, pubsub: "ClusterPubSub") -> None:
        self._pubsub_ref: "weakref.ref[ClusterPubSub]" = weakref.ref(pubsub)

    async def listen(self, event: object) -> None:
        """Call ``on_slots_changed`` on the pubsub if it is still alive."""
        target = self._pubsub_ref()
        # A dead reference means the pubsub was collected but its finalizer
        # has not run yet; the finalizer will unregister us, so do nothing.
        if target is not None:
            try:
                await target.on_slots_changed()
            except Exception as exc:
                # Never let one pubsub break slots-cache refresh handling for
                # the other listeners: log the failure and move on.
                logger.exception(
                    "pubsub %r raised during slots-cache change: %s: %s",
                    target,
                    type(exc).__name__,
                    exc,
                )
3224
3225
3226class ClusterPubSub(PubSub):
3227 """
3228 Async cluster implementation for pub/sub.
3229
3230 IMPORTANT: before using ClusterPubSub, read about the known limitations
3231 with pubsub in Cluster mode and learn how to workaround them:
3232 https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations
3233 """
3234
    def __init__(
        self,
        redis_cluster: "RedisCluster",
        node: Optional["ClusterNode"] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        **kwargs: Any,
    ) -> None:
        """
        When a pubsub instance is created without specifying a node, a single
        node will be transparently chosen for the pubsub connection on the
        first command execution. The node will be determined by:
        1. Hashing the channel name in the request to find its keyslot
        2. Selecting a node that handles the keyslot: If read_from_replicas is
        set to true or load_balancing_strategy is set, a replica can be selected.

        Shard (SSUBSCRIBE) subscriptions are handled separately. Each owning
        node gets its own PubSub in ``node_pubsub_mapping``, and this instance
        registers a listener for slots-cache refresh events so those
        subscriptions can be reconciled when slot ownership changes.

        :param redis_cluster: RedisCluster instance
        :param node: ClusterNode to connect to
        :param host: Host of the node to connect to
        :param port: Port of the node to connect to
        :param push_handler_func: Optional push handler function
        :param event_dispatcher: Optional event dispatcher; a new
            EventDispatcher is created when omitted
        :param kwargs: Additional keyword arguments forwarded to PubSub
        :raises DataError: if only one of ``host`` / ``port`` is given
        """
        self.node = None
        # Resolves (and validates) self.node, which decides the connection
        # pool chosen below.
        self.set_pubsub_node(redis_cluster, node, host, port)

        # Borrow the node's own connection pool via an adapter rather than
        # creating a second, detached ConnectionPool for pubsub.
        if self.node is not None:
            connection_pool = _ClusterNodePoolAdapter(self.node)
        else:
            connection_pool = None

        self.cluster = redis_cluster
        self.node_pubsub_mapping: Dict[str, PubSub] = {}
        # Reverse index: shard channel (normalized) -> owning node.name. Used to
        # route sunsubscribe calls and reconcile subscriptions after slot
        # migration / failover.
        self._shard_channel_to_node: Dict[Any, str] = {}
        # Dedicated lock for shard-subscription bookkeeping. Distinct from
        # PubSub.self._lock (which serializes wire I/O on the cluster-level
        # connection used by aclose / send_command / regular subscribe) so
        # that reconciliation cannot starve those unrelated coroutines
        # during long per-channel migrations.
        self._shard_state_lock: asyncio.Lock = asyncio.Lock()
        # Background tasks created by on_slots_changed; kept to prevent GC.
        self._reconcile_tasks: Set[asyncio.Task] = set()
        # Rebinds the name to one generator instance (shadowing the method of
        # the same name) so the round-robin position persists across
        # _sharded_message_generator calls.
        self._pubsubs_generator = self._pubsubs_generator()
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        super().__init__(
            connection_pool=connection_pool,
            encoder=redis_cluster.encoder,
            push_handler_func=push_handler_func,
            event_dispatcher=self._event_dispatcher,
            **kwargs,
        )
        # Subscribe to slots-cache change notifications so shard subscriptions
        # can be reconciled automatically after topology refreshes.
        nm_dispatcher = redis_cluster.nodes_manager._event_dispatcher
        self._slots_cache_listener = ClusterPubSubSlotsCacheListener(self)
        nm_dispatcher.register_listeners(
            {AsyncAfterSlotsCacheRefreshEvent: [self._slots_cache_listener]}
        )
        # Deterministic GC-time cleanup so short-lived pubsubs do not leak
        # listeners in the dispatcher when no slots-refresh event ever fires.
        weakref.finalize(
            self,
            _unregister_slots_cache_listener,
            weakref.ref(nm_dispatcher),
            self._slots_cache_listener,
            AsyncAfterSlotsCacheRefreshEvent,
        )
3313
3314 def set_pubsub_node(
3315 self,
3316 cluster: "RedisCluster",
3317 node: Optional["ClusterNode"] = None,
3318 host: Optional[str] = None,
3319 port: Optional[int] = None,
3320 ) -> None:
3321 """
3322 The pubsub node will be set according to the passed node, host and port
3323 When none of the node, host, or port are specified - the node is set
3324 to None and will be determined by the keyslot of the channel in the
3325 first command to be executed.
3326 RedisClusterException will be thrown if the passed node does not exist
3327 in the cluster.
3328 If host is passed without port, or vice versa, a DataError will be
3329 thrown.
3330 """
3331 if node is not None:
3332 # node is passed by the user
3333 self._raise_on_invalid_node(cluster, node, node.host, node.port)
3334 pubsub_node = node
3335 elif host is not None and port is not None:
3336 # host and port passed by the user
3337 node = cluster.get_node(host=host, port=port)
3338 self._raise_on_invalid_node(cluster, node, host, port)
3339 pubsub_node = node
3340 elif host is not None or port is not None:
3341 # only one of host and port is specified
3342 raise DataError("Specify both host and port")
3343 else:
3344 # nothing specified by the user
3345 pubsub_node = None
3346 self.node = pubsub_node
3347
3348 def get_pubsub_node(self) -> Optional["ClusterNode"]:
3349 """
3350 Get the node that is being used as the pubsub connection.
3351
3352 :return: The ClusterNode being used for pubsub, or None if not yet determined
3353 """
3354 return self.node
3355
3356 async def _resubscribe_shard_channels(self) -> None:
3357 # A single node can own multiple slot ranges, so a batched
3358 # ``SSUBSCRIBE`` covering every tracked channel would be rejected by
3359 # Redis with a ``CROSSSLOT`` error. Group by hash slot and emit one
3360 # ``SSUBSCRIBE`` per slot.
3361 by_slot: defaultdict[int, dict] = defaultdict(dict)
3362 for k, v in self.shard_channels.items():
3363 by_slot[key_slot(self.encoder.encode(k))][k] = v
3364 for subscriptions in by_slot.values():
3365 await self._resubscribe(subscriptions, self.ssubscribe)
3366
3367 def _get_node_pubsub(self, node: "ClusterNode") -> PubSub:
3368 """Get or create a PubSub instance for the given node."""
3369 try:
3370 return self.node_pubsub_mapping[node.name]
3371 except KeyError:
3372 pubsub = PubSub(
3373 connection_pool=_ClusterNodePoolAdapter(node),
3374 encoder=self.cluster.encoder,
3375 push_handler_func=self.push_handler_func,
3376 event_dispatcher=self._event_dispatcher,
3377 )
3378 # Replay shard subscriptions on reconnect with slot-aware grouping
3379 # so that channels spanning multiple slots owned by this node do
3380 # not trigger a CROSSSLOT error.
3381 pubsub._resubscribe_shard_channels = MethodType(
3382 ClusterPubSub._resubscribe_shard_channels, pubsub
3383 )
3384 self.node_pubsub_mapping[node.name] = pubsub
3385 return pubsub
3386
3387 def _find_node_name_for_pubsub(self, pubsub: PubSub) -> Optional[str]:
3388 for name, candidate in self.node_pubsub_mapping.items():
3389 if candidate is pubsub:
3390 return name
3391 return None
3392
3393 async def _sharded_message_generator(
3394 self, timeout: float = 0.0
3395 ) -> Tuple[Optional[PubSub], Optional[Dict[str, Any]]]:
3396 """Generate messages from shard channels across all nodes."""
3397 for _ in range(len(self.node_pubsub_mapping)):
3398 pubsub = next(self._pubsubs_generator)
3399 # Don't pass ignore_subscribe_messages here - let get_sharded_message
3400 # handle the filtering after processing subscription state changes
3401 message = await pubsub.get_message(
3402 ignore_subscribe_messages=False, timeout=timeout
3403 )
3404 if message is not None:
3405 return pubsub, message
3406 return None, None
3407
3408 def _pubsubs_generator(self) -> Generator[PubSub, None, None]:
3409 """Generator that yields PubSub instances in round-robin fashion."""
3410 while True:
3411 current_nodes = list(self.node_pubsub_mapping.values())
3412 if not current_nodes:
3413 return # Avoid infinite loop when no subscriptions exist
3414 yield from current_nodes
3415
    async def get_sharded_message(
        self,
        ignore_subscribe_messages: bool = False,
        timeout: float = 0.0,
        target_node: Optional["ClusterNode"] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Get a message from shard channels.

        Polls either ``target_node``'s per-node pubsub or, without a target,
        the per-node pubsubs in round-robin order. A ``sunsubscribe``
        confirmation also updates the cluster-level shard bookkeeping. It
        closes the delivering per-node pubsub once that pubsub has no
        subscriptions left.

        :param ignore_subscribe_messages: Whether to ignore subscribe messages
            (ssubscribe/sunsubscribe confirmations); they are also suppressed
            when the instance-level ``ignore_subscribe_messages`` is set
        :param timeout: Timeout for message retrieval
        :param target_node: Specific node to get message from; ``None`` is
            returned if no per-node pubsub exists for it
        :return: Message dictionary or None
        """
        pubsub: Optional[PubSub]
        if target_node:
            pubsub = self.node_pubsub_mapping.get(target_node.name)
            if pubsub:
                # Don't pass ignore_subscribe_messages here - let get_sharded_message
                # handle the filtering after processing subscription state changes
                message = await pubsub.get_message(
                    ignore_subscribe_messages=False, timeout=timeout
                )
            else:
                message = None
        else:
            pubsub, message = await self._sharded_message_generator(timeout=timeout)

        if message is None:
            return None
        # Only sunsubscribe mutates cluster-level shard state; bypassing the
        # lock on the data-message hot path keeps smessage delivery from
        # competing with the reconciliation task for _shard_state_lock.
        if str_if_bytes(message["type"]) == "sunsubscribe":
            # Serialize state mutation against reinitialize_shard_subscriptions
            # (background task). The blocking get_message above intentionally
            # runs outside the lock so reconciliation is not stalled by long
            # polls.
            async with self._shard_state_lock:
                if message["channel"] in self.pending_unsubscribe_shard_channels:
                    # User-initiated sunsubscribe: drop from cluster-level tracking.
                    self.pending_unsubscribe_shard_channels.remove(message["channel"])
                    self.shard_channels.pop(message["channel"], None)
                    self._shard_channel_to_node.pop(message["channel"], None)
                # Drop the per-node pubsub that delivered the confirmation once
                # it no longer holds any shard subscriptions, regardless of
                # whether the sunsubscribe was user-initiated or driven by
                # slot-migration reconciliation (_migrate_shard_channel, which
                # intentionally does not add the channel to
                # pending_unsubscribe_shard_channels). This releases the
                # dedicated connection that would otherwise linger.
                # Identifying the receiving pubsub directly (rather than via
                # the cluster's current slot map) is required after slot
                # migration, where the channel's owner is no longer the node
                # that received our original SSUBSCRIBE.
                if pubsub is not None and not pubsub.subscribed:
                    name = self._find_node_name_for_pubsub(pubsub)
                    if name is not None:
                        # Best-effort close; the mapping entry is dropped
                        # even if aclose() fails.
                        try:
                            await pubsub.aclose()
                        except Exception:
                            pass
                        self.node_pubsub_mapping.pop(name, None)

        # Only suppress subscribe/unsubscribe messages, not data messages (smessage)
        if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"):
            if self.ignore_subscribe_messages or ignore_subscribe_messages:
                return None
        return message
3485
    async def ssubscribe(self, *args: Any, **kwargs: Any) -> None:
        """
        Subscribe to shard channels.

        Each channel is subscribed on the per-node pubsub of the node that
        currently owns its key slot. Channels whose slot is not resolved to a
        node are skipped. A channel already tracked against a different node
        is migrated to the current owner.

        :param args: Channel names
        :param kwargs: Channel names with handlers
        """
        if args:
            args = list_or_args(args[0], args[1:])
        s_channels = dict.fromkeys(args)
        s_channels.update(kwargs)

        # Serialize against reinitialize_shard_subscriptions (background
        # task) so the reverse index, shard_channels, and node_pubsub_mapping
        # are not mutated concurrently. _migrate_shard_channel below does not
        # re-acquire this lock (asyncio.Lock is non-reentrant).
        async with self._shard_state_lock:
            for s_channel, handler in s_channels.items():
                node = self.cluster.get_node_from_key(s_channel)
                if not node:
                    continue
                # Lazy re-route: if this channel is already tracked against a
                # different node (e.g. after a slot migration), migrate it now
                # so the caller's intent is applied on the current owner.
                normalized_key = next(iter(self._normalize_keys({s_channel: None})))
                old_name = self._shard_channel_to_node.get(normalized_key)
                if old_name and old_name != node.name:
                    # Match PubSub.ssubscribe() dict.update() semantics: the
                    # caller's newly supplied handler (including None) always
                    # overrides any previously registered handler.
                    await self._migrate_shard_channel(
                        normalized_key,
                        handler,
                        old_name,
                        node,
                    )
                    continue
                pubsub = self._get_node_pubsub(node)
                if handler:
                    await pubsub.ssubscribe(**{s_channel: handler})
                else:
                    await pubsub.ssubscribe(s_channel)
                # Mirror the per-node state into cluster-level tracking and
                # cancel any pending unsubscribe for this channel.
                self.shard_channels.update(pubsub.shard_channels)
                self._shard_channel_to_node[normalized_key] = node.name
                self.pending_unsubscribe_shard_channels.difference_update(
                    self._normalize_keys({s_channel: None})
                )
3533
    async def sunsubscribe(self, *args: Any) -> None:
        """
        Unsubscribe from shard channels.

        Each channel is routed to the per-node pubsub that holds it (via the
        reverse index, falling back to the current slot owner). Channels with
        no matching per-node pubsub are skipped. Cluster-level tracking is
        only updated when the confirmation arrives (see get_sharded_message);
        here the channels are just marked as pending.

        :param args: Channel names to unsubscribe from. If empty, unsubscribe from all.
        """
        if args:
            args = list_or_args(args[0], args[1:])
        else:
            args = list(self.shard_channels.keys())

        # Serialize against reinitialize_shard_subscriptions: the reverse
        # index and node_pubsub_mapping must not change between the lookup
        # and the per-node sunsubscribe call below.
        async with self._shard_state_lock:
            for s_channel in args:
                normalized_key = next(iter(self._normalize_keys({s_channel: None})))
                # Route via the reverse index so we unsubscribe on the node
                # that actually holds the subscription. After a slot migration
                # the cluster's current owner may no longer be that node.
                name = self._shard_channel_to_node.get(normalized_key)
                if name and name in self.node_pubsub_mapping:
                    pubsub = self.node_pubsub_mapping[name]
                else:
                    node = self.cluster.get_node_from_key(s_channel)
                    if not node or node.name not in self.node_pubsub_mapping:
                        continue
                    pubsub = self.node_pubsub_mapping[node.name]
                await pubsub.sunsubscribe(s_channel)
                self.pending_unsubscribe_shard_channels.update(
                    pubsub.pending_unsubscribe_shard_channels
                )
3566
3567 async def reinitialize_shard_subscriptions(self) -> None:
3568 """
3569 Reconcile per-node shard subscriptions against the cluster's current
3570 slot ownership map. For each tracked shard channel whose owning node
3571 has changed (e.g. after CLUSTER SETSLOT / failover), sunsubscribe on
3572 the old node's pubsub and ssubscribe on the new owner's pubsub,
3573 preserving any registered handler.
3574 """
3575 uncovered: list = []
3576 made_progress = False
3577 first_migrate_error: Optional[BaseException] = None
3578 async with self._shard_state_lock:
3579 for channel, handler in list(self.shard_channels.items()):
3580 try:
3581 new_node = self.cluster.get_node_from_key(channel)
3582 except SlotNotCoveredError:
3583 # Slot is transiently uncovered (mid-migration / partial
3584 # topology refresh). Defer this channel so coverable
3585 # siblings still reconcile this pass; we surface the
3586 # error below so the caller (and logs) know not every
3587 # channel was reconciled. Retry happens on the next
3588 # slots-cache change notification.
3589 uncovered.append(channel)
3590 continue
3591 old_name = self._shard_channel_to_node.get(channel)
3592 if old_name == new_node.name:
3593 continue
3594 try:
3595 await self._migrate_shard_channel(
3596 channel, handler, old_name, new_node
3597 )
3598 made_progress = True
3599 except (ConnectionError, TimeoutError, OSError) as e:
3600 # Transient connectivity error while subscribing on the
3601 # new owner (or unsubscribing on the old owner if its
3602 # handler chose to re-raise). Do not abort reconciliation
3603 # for sibling channels: _shard_channel_to_node was not
3604 # advanced for this channel, so the next slots-cache
3605 # change notification will retry it.
3606 logger.warning(
3607 "shard channel %r migration deferred: %s: %s",
3608 channel,
3609 type(e).__name__,
3610 e,
3611 )
3612 if first_migrate_error is None:
3613 first_migrate_error = e
3614 continue
3615 # Garbage-collect per-node pubsubs that no longer hold any
3616 # subscription so their connections are released.
3617 for name, pubsub in list(self.node_pubsub_mapping.items()):
3618 if not pubsub.subscribed:
3619 try:
3620 await pubsub.aclose()
3621 except Exception:
3622 pass
3623 self.node_pubsub_mapping.pop(name, None)
3624 if uncovered:
3625 # Surface the uncovered channels so the caller (and observer
3626 # notification path) knows reconciliation was incomplete. All
3627 # coverable siblings have already been migrated above.
3628 raise SlotNotCoveredError(
3629 f"{len(uncovered)} shard channel(s) left unreconciled; "
3630 f"slot(s) not covered by the cluster: {uncovered!r}"
3631 )
3632 if first_migrate_error is not None and not made_progress:
3633 # Every migration attempted in this pass failed transiently and
3634 # nothing else made progress. Re-raise the first caught error
3635 # (typically the root cause; later failures are often downstream
3636 # symptoms of the same unreachable node) so the task's done-
3637 # callback surfaces a single representative failure through the
3638 # same logger channel used for SlotNotCoveredError. Per-channel
3639 # WARNINGs above preserve the full forensic detail.
3640 raise first_migrate_error
3641
    async def _migrate_shard_channel(
        self,
        channel: Any,
        handler: Optional[Callable],
        old_name: Optional[str],
        new_node: "ClusterNode",
    ) -> None:
        """
        Move one shard subscription from ``old_name``'s pubsub to ``new_node``.

        The caller must already hold ``_shard_state_lock``. Unsubscribing on
        the old node is best-effort. Subscribing on the new node is not:
        errors from that step propagate, and in that case the reverse index is
        left pointing at the old node.

        :param channel: normalized channel key (bytes or str)
        :param handler: message handler to register on the new node, or None
        :param old_name: name of the node currently holding the subscription,
            if known
        :param new_node: node that now owns the channel's slot
        """
        # Detach from the old per-node pubsub, best-effort: the old node may
        # already be unreachable during migration / failover.
        if old_name and old_name in self.node_pubsub_mapping:
            old_pubsub = self.node_pubsub_mapping[old_name]
            try:
                await old_pubsub.sunsubscribe(channel)
            except (ConnectionError, TimeoutError, OSError):
                # redis-py's Connection has already called ``disconnect()``
                # before raising (see Connection.read_response /
                # send_packed_command with ``disconnect_on_error=True``),
                # so ``old_pubsub``'s dedicated socket is gone. Two cases:
                #
                # 1. The old node is no longer in the cluster topology
                #    (e.g. removed by failover / topology refresh): no
                #    reconnect target exists, so ``old_pubsub.subscribed``
                #    would stay True forever and the end-of-pass GC block
                #    would skip it. Drop it eagerly so the round-robin
                #    generator does not keep yielding a dead pubsub that
                #    produces periodic errors from ``get_sharded_message``.
                # 2. The old node is still known (transiently slow /
                #    unreachable): ``PubSub._execute`` auto-reconnects and
                #    ``on_connect`` re-subscribes to remaining channels,
                #    so other subscriptions on the same pubsub recover
                #    naturally. Leave it alone.
                if self.cluster.get_node(node_name=old_name) is None:
                    try:
                        await old_pubsub.aclose()
                    except Exception:
                        pass
                    self.node_pubsub_mapping.pop(old_name, None)
        # Attach to the new per-node pubsub, preserving the handler. Decode to
        # a text key only when we must pass it as a kwarg (handler present).
        new_pubsub = self._get_node_pubsub(new_node)
        if handler:
            decoded = (
                self.encoder.decode(channel, force=True)
                if isinstance(channel, (bytes, bytearray))
                else channel
            )
            await new_pubsub.ssubscribe(**{decoded: handler})
        else:
            await new_pubsub.ssubscribe(channel)
        # Only reached after a successful ssubscribe: record the new owner and
        # clear any pending unsubscribe for this channel.
        self.shard_channels.update(new_pubsub.shard_channels)
        normalized_key = next(iter(self._normalize_keys({channel: None})))
        self._shard_channel_to_node[normalized_key] = new_node.name
        self.pending_unsubscribe_shard_channels.difference_update(
            self._normalize_keys({channel: None})
        )
3697
3698 async def on_slots_changed(self) -> None:
3699 # Observer hook invoked by NodesManager after a slots-cache refresh.
3700 # Schedule reconciliation as a separate task so the caller's code
3701 # path (typically MovedError handling in _execute_command) is not
3702 # blocked on the network I/O performed by reinitialize_shard_
3703 # subscriptions. No-op when there are no shard subscriptions to
3704 # reconcile.
3705 if not self.shard_channels:
3706 return
3707 task = asyncio.create_task(self.reinitialize_shard_subscriptions())
3708 self._reconcile_tasks.add(task)
3709 task.add_done_callback(self._reconcile_tasks.discard)
3710 # Consume the task's exception (if any) so Python does not emit a
3711 # "Task exception was never retrieved" warning. reinitialize_shard_
3712 # subscriptions surfaces SlotNotCoveredError when a slot is still
3713 # transiently uncovered; route it through the same logger channel
3714 # as sync ClusterPubSubSlotsCacheListener for consistent observability.
3715 task.add_done_callback(self._log_reconcile_task_exception)
3716
3717 @staticmethod
3718 def _log_reconcile_task_exception(task: "asyncio.Task") -> None:
3719 if task.cancelled():
3720 return
3721 exc = task.exception()
3722 if exc is not None:
3723 logger.error(
3724 "shard subscription reconciliation failed: %r", exc, exc_info=exc
3725 )
3726
3727 def get_redis_connection(self) -> Optional["AbstractConnection"]:
3728 """
3729 Get the Redis connection of the pubsub connected node.
3730
3731 Returns the pubsub's dedicated connection (acquired from its own
3732 connection pool), not from the ClusterNode's connection pool.
3733 This avoids the connection pool resource leak that would occur
3734 if we called node.acquire_connection() without releasing.
3735 """
3736 # Return the pubsub's own dedicated connection, which is acquired
3737 # from self.connection_pool when executing pubsub commands.
3738 # This is safe because it's the connection dedicated to this pubsub
3739 # instance, not a shared pool connection from the ClusterNode.
3740 return self.connection
3741
3742 async def aclose(self) -> None:
3743 """
3744 Disconnect the pubsub connection.
3745 """
3746 # Cancel and gather in-flight reconciliation tasks BEFORE acquiring
3747 # _shard_state_lock. The tasks themselves take that lock inside
3748 # reinitialize_shard_subscriptions; since asyncio.Lock is non-
3749 # reentrant, gathering while holding it would deadlock. Awaiting
3750 # each task with suppressed CancelledError also avoids unhandled-
3751 # exception warnings if the task was created but not yet scheduled.
3752 if self._reconcile_tasks:
3753 tasks = list(self._reconcile_tasks)
3754 for task in tasks:
3755 task.cancel()
3756 await asyncio.gather(*tasks, return_exceptions=True)
3757 # Hold _shard_state_lock across the rest of the teardown so it
3758 # observes the same mutual-exclusion discipline as ssubscribe /
3759 # sunsubscribe / get_sharded_message / reinitialize_shard_
3760 # subscriptions, which all mutate shard_channels,
3761 # _shard_channel_to_node, and node_pubsub_mapping under this lock.
3762 # Without it, super().aclose() rebinds shard_channels and
3763 # pending_unsubscribe_shard_channels in parallel with a concurrent
3764 # user-coroutine mutation that resumes during one of the awaits
3765 # below, silently dropping subscription intent.
3766 async with self._shard_state_lock:
3767 self._reconcile_tasks.clear()
3768 # Close all shard pubsub instances first
3769 for pubsub in self.node_pubsub_mapping.values():
3770 await pubsub.aclose()
3771 # Drop the now-dead per-node pubsubs from the mapping so the
3772 # round-robin in _pubsubs_generator / _sharded_message_generator
3773 # cannot yield them between teardown and re-subscription.
3774 self.node_pubsub_mapping.clear()
3775 # _pubsubs_generator captures node_pubsub_mapping.values() into
3776 # a local list inside ``yield from``; clearing the mapping does
3777 # not reach references already held by that captured snapshot,
3778 # so a generator suspended mid-yield-from would still surface
3779 # the now-aclose()'d per-node pubsubs after re-subscription.
3780 # Recreate it to drop the captured list. type(self) bypasses
3781 # the instance-level self-shadow established at __init__
3782 # (self._pubsubs_generator = self._pubsubs_generator()).
3783 self._pubsubs_generator = type(self)._pubsubs_generator( # type: ignore[method-assign]
3784 self
3785 )
3786 # Let parent handle self.connection disconnect under the lock
3787 # (includes disconnect, release to pool, and clearing
3788 # self.connection)
3789 await super().aclose()
3790 # Clear the reverse index so a reused instance doesn't route
3791 # against stale mappings. super().aclose() has already cleared
3792 # shard_channels.
3793 self._shard_channel_to_node.clear()
3794
3795 def _raise_on_invalid_node(
3796 self,
3797 redis_cluster: "RedisCluster",
3798 node: Optional["ClusterNode"],
3799 host: Optional[str],
3800 port: Optional[int],
3801 ) -> None:
3802 """
3803 Raise a RedisClusterException if the node is None or doesn't exist in
3804 the cluster.
3805 """
3806 if node is None or redis_cluster.get_node(node_name=node.name) is None:
3807 raise RedisClusterException(
3808 f"Node {host}:{port} doesn't exist in the cluster"
3809 )
3810
3811 async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
3812 """
3813 Execute a command on the appropriate cluster node.
3814
3815 Taken code from redis-py and tweaked to make it work within a cluster.
3816 """
3817 # NOTE: don't parse the response in this function -- it could pull a
3818 # legitimate message off the stack if the connection is already
3819 # subscribed to one or more channels
3820
3821 # For shard commands, route to appropriate node
3822 command = args[0].upper() if args else ""
3823 if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"):
3824 if len(args) > 1:
3825 channel = args[1]
3826 node = self.cluster.get_node_from_key(channel)
3827 if node:
3828 pubsub = self._get_node_pubsub(node)
3829 return await pubsub.execute_command(*args, **kwargs)
3830
3831 # For other commands, use the set node or lazily discover one
3832 if self.connection is None:
3833 if self.connection_pool is None:
3834 if len(args) > 1:
3835 # Hash the first channel and get one of the nodes holding
3836 # this slot
3837 channel = args[1]
3838 slot = self.cluster.keyslot(channel)
3839 node = self.cluster.nodes_manager.get_node_from_slot(
3840 slot,
3841 self.cluster.read_from_replicas,
3842 self.cluster.load_balancing_strategy,
3843 )
3844 else:
3845 # Get a random node
3846 node = self.cluster.get_random_node()
3847 self.node = node
3848 self.connection_pool = _ClusterNodePoolAdapter(node)
3849
3850 # Now we have a connection_pool, use parent's execute_command
3851 return await super().execute_command(*args, **kwargs)