Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 22%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import copy
3import inspect
4import re
5import time
6import warnings
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 AsyncIterator,
11 Awaitable,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Literal,
17 Mapping,
18 MutableMapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27 cast,
28)
30from redis._defaults import (
31 DEFAULT_RETRY_BASE,
32 DEFAULT_RETRY_CAP,
33 DEFAULT_RETRY_COUNT,
34 DEFAULT_SOCKET_CONNECT_TIMEOUT,
35 DEFAULT_SOCKET_READ_SIZE,
36 DEFAULT_SOCKET_TIMEOUT,
37)
38from redis._parsers.helpers import bool_ok, get_response_callbacks
39from redis.asyncio.connection import (
40 Connection,
41 ConnectionPool,
42 SSLConnection,
43 UnixDomainSocketConnection,
44)
45from redis.asyncio.lock import Lock
46from redis.asyncio.observability.recorder import (
47 record_error_count,
48 record_operation_duration,
49 record_pubsub_message,
50)
51from redis.asyncio.retry import Retry
52from redis.backoff import ExponentialWithJitterBackoff
53from redis.client import (
54 EMPTY_RESPONSE,
55 NEVER_DECODE,
56 AbstractRedis,
57 CaseInsensitiveDict,
58)
59from redis.commands import (
60 AsyncCoreCommands,
61 AsyncRedisModuleCommands,
62 AsyncSentinelCommands,
63 list_or_args,
64)
65from redis.commands.helpers import parse_pubsub_subscriptions, pubsub_subscription_args
66from redis.credentials import CredentialProvider
67from redis.driver_info import DriverInfo, resolve_driver_info
68from redis.event import (
69 AfterPooledConnectionsInstantiationEvent,
70 AfterPubSubConnectionInstantiationEvent,
71 AfterSingleConnectionInstantiationEvent,
72 ClientType,
73 EventDispatcher,
74)
75from redis.exceptions import (
76 ConnectionError,
77 ExecAbortError,
78 PubSubError,
79 RedisError,
80 ResponseError,
81 WatchError,
82)
83from redis.observability.attributes import PubSubDirection
84from redis.typing import ChannelT, EncodableT, KeyT, PubSubHandler, Subscription
85from redis.utils import (
86 SENTINEL,
87 SSL_AVAILABLE,
88 _set_info_logger,
89 deprecated_args,
90 deprecated_function,
91 safe_str,
92 str_if_bytes,
93 truncate_text,
94)
# Imports needed only by static type checkers (never executed at runtime).
if TYPE_CHECKING:
    from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications
    from redis.commands.core import Script

# The ssl enum types are referenced in this module only inside string /
# forward-reference annotations, so at runtime (or when ssl is unavailable)
# plain ``None`` placeholders are bound under the same names.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = VerifyMode = VerifyFlags = None

# Generic type parameters shared by annotations in this module.
# Any key-like value (bounded by redis.typing.KeyT).
_KeyT = TypeVar("_KeyT", bound=KeyT)
# Constrained: either a key or an encodable command argument.
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# The concrete client type; lets methods such as initialize() return ``self``
# with its precise (sub)class type.
_RedisT = TypeVar("_RedisT", bound="Redis")
# A mapping whose keys are channel names.
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
class ResponseCallbackProtocol(Protocol):
    """
    Structural type of a synchronous response callback.

    The callable receives the raw server response plus any remaining command
    options as keyword arguments; whatever it returns is used as the parsed
    result of the command.
    """

    def __call__(self, response: Any, **kwargs):
        """Convert ``response`` into the value handed back to the caller."""


class AsyncResponseCallbackProtocol(Protocol):
    """
    Structural type of an asynchronous response callback.

    Same call signature as :class:`ResponseCallbackProtocol`, but the callable
    is a coroutine function whose awaited result is the parsed value.
    """

    async def __call__(self, response: Any, **kwargs):
        """Coroutine variant of ``ResponseCallbackProtocol.__call__``."""


# Either flavour of callback may be stored in ``Redis.response_callbacks``.
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    # Text of the DeprecationWarning emitted by both from_url() and __init__()
    # when "auto_close_connection_pool" is passed explicitly. Kept in one place
    # so the two call sites cannot drift apart.
    _AUTO_CLOSE_DEPRECATION_MESSAGE = (
        '"auto_close_connection_pool" is deprecated '
        "since version 5.0.1. "
        "Please create a ConnectionPool explicitly and "
        "provide to the Redis() constructor instead."
    )

    @classmethod
    def from_url(
        cls: Type["Redis"],
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ) -> "Redis":
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        ``auto_close_connection_pool`` is deprecated; when left as ``None``
        the client closes the pool it created here when it is closed.
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(DeprecationWarning(cls._AUTO_CLOSE_DEPRECATION_MESSAGE))
        else:
            # The pool was created here, so by default the client owns it.
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: str | int = 0,
        password: str | None = None,
        socket_timeout: float | None = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: float | None = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        socket_keepalive: bool | None = True,
        socket_keepalive_options: Mapping[int, int | bytes] | object | None = SENTINEL,
        connection_pool: ConnectionPool | None = None,
        unix_socket_path: str | None = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(
                base=DEFAULT_RETRY_BASE, cap=DEFAULT_RETRY_CAP
            ),
            retries=DEFAULT_RETRY_COUNT,
        ),
        retry_on_error: list | None = None,
        ssl: bool = False,
        ssl_keyfile: str | None = None,
        ssl_certfile: str | None = None,
        ssl_cert_reqs: "str | VerifyMode" = "required",
        ssl_include_verify_flags: List["VerifyFlags"] | None = None,
        ssl_exclude_verify_flags: List["VerifyFlags"] | None = None,
        ssl_ca_certs: str | None = None,
        ssl_ca_data: str | None = None,
        ssl_ca_path: str | None = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: "TLSVersion | None" = None,
        ssl_ciphers: str | None = None,
        ssl_password: str | None = None,
        max_connections: int | None = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: str | None = None,
        lib_name: str | object | None = SENTINEL,
        lib_version: str | object | None = SENTINEL,
        driver_info: DriverInfo | object | None = SENTINEL,
        username: str | None = None,
        auto_close_connection_pool: bool | None = None,
        redis_connect_func=None,
        credential_provider: CredentialProvider | None = None,
        protocol: int | None = None,
        legacy_responses: bool = True,
        event_dispatcher: EventDispatcher | None = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object(in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used.

        Args:

        socket_keepalive:
            if `True`, TCP keepalive is enabled for TCP socket connections.
            Argument is ignored when connection_pool is provided.
        socket_keepalive_options:
            mapping of TCP keepalive socket option constants to values, for
            example `{socket.TCP_KEEPIDLE: 30}`. If left unspecified, redis-py
            uses TCP keepalive defaults when `socket_keepalive` is enabled:
            idle 30 seconds, interval 5 seconds, and 3 probes. Platform-specific
            options that are not available are skipped. Pass `None` or `{}` to
            avoid setting additional TCP keepalive options. Argument is ignored
            when connection_pool is provided.
        retry:
            the default instance is shared between calls, so it is deep-copied
            before being handed to a newly created pool.
        auto_close_connection_pool:
            deprecated; only meaningful when no ``connection_pool`` is given.
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(DeprecationWarning(self._AUTO_CLOSE_DEPRECATION_MESSAGE))
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []

            # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version.
            computed_driver_info = resolve_driver_info(
                driver_info, lib_name, lib_version
            )

            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "socket_read_size": socket_read_size,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # The default Retry is a shared default-argument object; copy it
                # so pools never share (and mutate) the same instance.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "driver_info": computed_driver_info,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
                "legacy_responses": legacy_responses,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                            "ssl_password": ssl_password,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        connection_kwargs = self.connection_pool.connection_kwargs
        self.response_callbacks = CaseInsensitiveDict(
            get_response_callbacks(
                user_protocol=connection_kwargs.get("protocol"),
                legacy_responses=connection_kwargs.get("legacy_responses", True),
            )
        )

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        """Allow ``client = await Redis(...)``; equivalent to ``initialize()``."""
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """
        Prepare the client for use and return it.

        For a single-connection client, the dedicated connection is taken from
        the pool (while holding ``_single_conn_lock``) the first time this runs,
        and the single-connection instantiation event is dispatched for it.
        Subsequent calls reuse the existing connection.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()
                    # Dispatch only when a connection was actually created:
                    # execute_command() calls initialize() before every command,
                    # so dispatching unconditionally would repeat the event on
                    # each call for the same connection.
                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        """Return the ``retry`` entry of the pool's connection kwargs, if set."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        """Store ``retry`` in the pool's connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load function functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        ``func`` may be a plain callable or return an awaitable. On
        ``WatchError`` the whole sequence is retried, optionally sleeping
        ``watch_delay`` seconds first. Returns the result of ``execute()``,
        or the value returned by ``func`` when ``value_from_callable`` is True.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def keyspace_notifications(
        self,
        key_prefix: Union[str, bytes, None] = None,
        ignore_subscribe_messages: bool = True,
    ) -> "AsyncKeyspaceNotifications":
        """
        Return an :class:`~redis.asyncio.keyspace_notifications.AsyncKeyspaceNotifications`
        object for subscribing to keyspace and keyevent notifications.

        Note: Keyspace notifications must be enabled on the Redis server via
        the ``notify-keyspace-events`` configuration option.

        Args:
            key_prefix: Optional prefix to filter and strip from keys in
                        notifications.
            ignore_subscribe_messages: If True, subscribe/unsubscribe
                                       confirmations are not returned by
                                       get_message/listen.
        """
        # Imported lazily; at module level the name is only imported for type checking.
        from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications

        return AsyncKeyspaceNotifications(
            self,
            key_prefix=key_prefix,
            ignore_subscribe_messages=ignore_subscribe_messages,
        )

    def monitor(self) -> "Monitor":
        """Return a Monitor bound to this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """Return a new single-connection client sharing this client's pool."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        # Shielded so a cancellation arriving here cannot leave the counter or
        # the pool half-updated.
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop (e.g. interpreter shutdown).
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(
        self,
        conn: Connection,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).

        The failed attempt's duration is recorded only when an error, a
        failure count within the retry budget, and a ``start_time`` are all
        available; without ``start_time`` no duration can be computed.
        """
        if (
            error
            and failure_count is not None
            and start_time is not None
            and failure_count <= conn.retry.get_retries()
        ):
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
                error=error,
                retry_attempts=failure_count,
            )

        await conn.disconnect(error=error, failure_count=failure_count)

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """
        Execute a command and return a parsed response.

        ``args[0]`` is the command name. The dedicated connection is used for
        single-connection clients, otherwise one is borrowed from the pool and
        released afterwards. The call goes through the connection's retry
        policy; operation duration (on success) and error counts (on failure)
        are reported to the observability recorder.
        """
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._close_connection(
                conn, error, failure_count, start_time, command_name
            )

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            result = await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )
            return result
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            # Only connections borrowed for this call go back to the pool.
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """
        Parses a response from the Redis server

        ``NEVER_DECODE`` in ``options`` reads the reply without decoding.
        ``EMPTY_RESPONSE`` in ``options`` is returned instead of raising when
        the server replies with an error. A registered response callback (sync
        or async) for ``command_name`` post-processes the reply.
        """
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response


StrictRedis = Redis
# One parsed MONITOR line, as built by Monitor.next_command():
#   time           - server timestamp, converted with float()
#   db             - database index, converted with int()
#   client_address - "lua", "unix", or the host part of "host:port"
#   client_port    - port text (or socket path for unix clients; "" for lua)
#   client_type    - one of "lua", "unix", "tcp"
#   command        - the space-joined command and its arguments
MonitorCommandInfo = TypedDict(
    "MonitorCommandInfo",
    {
        "time": float,
        "db": int,
        "client_address": str,
        "client_port": str,
        "client_type": str,
        "command": str,
    },
)
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use as an async context manager: entering sends MONITOR on a connection
    taken from the pool; exiting disconnects that connection and returns it
    to the pool.
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Take a connection from the pool unless one is already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def _release_connection(self) -> None:
        """Disconnect the held connection (if any) and give it back to the pool."""
        conn = self.connection
        if conn is None:
            return
        # Drop the reference first: once released, the pool may hand this
        # connection to another user, so this Monitor must never touch it again.
        self.connection = None
        await conn.disconnect()
        await self.connection_pool.release(conn)

    async def __aenter__(self):
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release the
            # connection here instead of leaking it from the pool.
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        await self._release_connection()

    async def next_command(self) -> MonitorCommandInfo:
        """
        Parse the response from a monitor command

        Reads one MONITOR line and returns it split into timestamp, db,
        client address/port/type and the command text.

        Raises:
            RedisError: if the line does not match the expected MONITOR format.
        """
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        if m is None:
            raise RedisError(f"Unable to parse MONITOR output: {response!r}")
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.

    Subscriptions are kept in three dicts mapping a normalized name to an
    optional handler: ``channels``, ``patterns`` and ``shard_channels``.
    ``on_connect`` replays them whenever the connection (re)connects.
    """

    # Message types carrying published data; these may be routed to handlers.
    PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
    # Confirmations that remove a subscription from local bookkeeping.
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
    # Payload of health-check PINGs, so their replies can be filtered out.
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to lookup channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        # The two shapes a health-check reply may be read back as (a
        # ["pong", message] list or the bare message); parse_response()
        # drops responses equal to either one.
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self.shard_channels = {}
        self.pending_unsubscribe_shard_channels = set()
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    def __del__(self):
        # __init__ may have raised before ``connection`` was assigned (the
        # same situation aclose() guards against), so don't assume it exists.
        connection = getattr(self, "connection", None)
        if connection:
            connection.deregister_connect_callback(self.on_connect)

    async def aclose(self):
        """Disconnect, release the connection and forget all subscriptions."""
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                # Use nowait=True to avoid awaiting StreamWriter.wait_closed(),
                # which can deadlock when a concurrent reader task (e.g. one
                # running pubsub.run() or get_message(block=True)) still holds
                # the transport. See https://github.com/redis/redis-py/issues/3941
                await self.connection.disconnect(nowait=True)
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()
            self.shard_channels = {}
            self.pending_unsubscribe_shard_channels = set()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    async def _resubscribe(self, subscribed, subscribe_fn) -> None:
        # Replay handler-backed subscriptions as positional Subscription objects
        # so binary names never need to be decoded into keyword argument keys.
        subscriptions = pubsub_subscription_args(subscribed)
        if subscriptions:
            await subscribe_fn(*subscriptions)

    async def _resubscribe_shard_channels(self) -> None:
        await self._resubscribe(self.shard_channels, self.ssubscribe)

    async def on_connect(self, connection: Connection):
        """Re-subscribe to any channels and patterns previously subscribed to"""
        self.pending_unsubscribe_channels.clear()
        self.pending_unsubscribe_patterns.clear()
        self.pending_unsubscribe_shard_channels.clear()
        if self.channels:
            await self._resubscribe(self.channels, self.subscribe)
        if self.patterns:
            await self._resubscribe(self.patterns, self.psubscribe)
        if self.shard_channels:
            await self._resubscribe_shard_channels()

    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        return bool(self.channels or self.patterns or self.shard_channels)

    async def execute_command(self, *args: EncodableT):
        """Execute a publish/subscribe command"""

        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        await self.connect()
        connection = self.connection
        # Health checks are only done by send_command while not subscribed.
        kwargs = {"check_health": not self.subscribed}
        await self._execute(connection, connection.send_command, *args, **kwargs)

    async def connect(self):
        """
        Ensure that the PubSub is connected
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )

    async def _reconnect(
        self,
        conn,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        In this error handler we are trying to reconnect to the server.

        While ``failure_count`` does not exceed the configured retries, the
        failed attempt's duration is recorded (when ``command_name`` is set;
        ``start_time`` must then be provided). The connection is always
        disconnected and connected again.
        """
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            if command_name:
                await record_operation_duration(
                    command_name=command_name,
                    duration_seconds=time.monotonic() - start_time,
                    server_address=getattr(conn, "host", None),
                    server_port=getattr(conn, "port", None),
                    db_namespace=str(conn.db),
                    error=error,
                    retry_attempts=failure_count,
                )
        await conn.disconnect(error=error, failure_count=failure_count)
        await conn.connect()

    async def _execute(self, conn, command, *args, **kwargs):
        """
        Connect manually upon disconnection. If the Redis server is down,
        this will fail and raise a ConnectionError as desired.
        After reconnection, the ``on_connect`` callback should have been
        called by the # connection to resubscribe us to any channels and
        patterns we were previously listening to

        ``command`` is invoked through ``conn.retry``; the first positional
        arg (if any) is used as the command name for metrics. Errors that
        escape the retry loop are counted and re-raised.
        """
        if not len(args) == 0:
            command_name = args[0]
        else:
            command_name = None

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._reconnect(conn, error, failure_count, start_time, command_name)

        try:
            response = await conn.retry.call_with_retry(
                lambda: command(*args, **kwargs),
                failure_callback,
                with_failure_count=True,
            )

            if command_name:
                await record_operation_duration(
                    command_name=command_name,
                    duration_seconds=time.monotonic() - start_time,
                    server_address=getattr(conn, "host", None),
                    server_port=getattr(conn, "port", None),
                    db_namespace=str(conn.db),
                )

            return response
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """
        Parse the response from a publish/subscribe command.

        Args:
            block: If True, block indefinitely until a message is available.
                   If False, return immediately if no message is available.
                   Default: True
            timeout: The timeout in seconds for reading a response when block=False.
                     This parameter is ignored when block=True.
                     Default: 0 (return immediately if no data available)

        Returns:
            The parsed response from the server, or None if no message is available
            within the timeout period (when block=False).

        Important:
            The block and timeout parameters work together:
            - When block=True: timeout is IGNORED, method blocks indefinitely
            - When block=False: timeout is USED, method returns after timeout expires

            Typically, you should use get_message(timeout=X) instead of calling
            parse_response() directly. The get_message() method automatically sets
            block=False when a timeout is provided, and block=True when timeout=None.

        Example:
            # Block indefinitely (timeout is ignored)
            response = await pubsub.parse_response(block=True, timeout=0.1)

            # Non-blocking with 0.1 second timeout
            response = await pubsub.parse_response(block=False, timeout=0.1)

            # Non-blocking, return immediately
            response = await pubsub.parse_response(block=False, timeout=0)

            # Recommended: use get_message() instead
            msg = await pubsub.get_message(timeout=0.1)  # automatically sets block=False
            msg = await pubsub.get_message(timeout=None)  # automatically sets block=True
        """
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as user might not expect it
            return None
        return response

    async def check_health(self):
        """Send a health-check PING if the connection's interval has elapsed."""
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        if (
            conn.health_check_interval
            and asyncio.get_running_loop().time() > conn.next_health_check
        ):
            await conn.send_command(
                "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
            )

    def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
        """
        normalize channel/pattern names to be either bytes or strings
        based on whether responses are automatically decoded. this saves us
        from coercing the value for each message coming in.
        """
        encode = self.encoder.encode
        decode = self.encoder.decode
        return {decode(encode(k)): v for k, v in data.items()}  # type: ignore[return-value]  # noqa: E501

    async def psubscribe(
        self, *args: ChannelT | Subscription, **kwargs: PubSubHandler
    ) -> None:
        """
        Subscribe to channel patterns.
        Patterns supplied as keyword arguments expect a pattern name as the
        key and a callable as the value.
        ``Subscription`` objects can also be supplied positionally with an
        optional handler.
        A pattern's callable will be invoked automatically
        when a message is received on that pattern rather than producing a
        message via ``listen()``.
        """
        new_patterns = parse_pubsub_subscriptions(args, kwargs)
        ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
        # update the patterns dict AFTER we send the command. we don't want to
        # subscribe twice to these patterns, once for the command and again
        # for the reconnection.
        new_patterns = self._normalize_keys(new_patterns)
        self.patterns.update(new_patterns)
        self.pending_unsubscribe_patterns.difference_update(new_patterns)
        return ret_val

    def punsubscribe(self, *args: ChannelT) -> Awaitable:
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        patterns: Iterable[ChannelT]
        if args:
            # Wrapping args[0] keeps each positional arg a single pattern.
            parsed_args = list_or_args((args[0],), args[1:])
            patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
        else:
            parsed_args = []
            patterns = self.patterns
        self.pending_unsubscribe_patterns.update(patterns)
        return self.execute_command("PUNSUBSCRIBE", *parsed_args)

    async def subscribe(
        self, *args: ChannelT | Subscription, **kwargs: PubSubHandler
    ) -> None:
        """
        Subscribe to channels.
        Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value.
        ``Subscription`` objects can also be supplied positionally with an
        optional handler.
        A channel's callable will be invoked automatically
        when a message is received on that channel rather than producing a
        message via ``listen()`` or ``get_message()``.
        """
        new_channels = parse_pubsub_subscriptions(args, kwargs)
        ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
        # update the channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_channels = self._normalize_keys(new_channels)
        self.channels.update(new_channels)
        self.pending_unsubscribe_channels.difference_update(new_channels)
        return ret_val

    def unsubscribe(self, *args) -> Awaitable:
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            parsed_args = []
            channels = self.channels
        self.pending_unsubscribe_channels.update(channels)
        return self.execute_command("UNSUBSCRIBE", *parsed_args)

    async def ssubscribe(
        self,
        *args: ChannelT | Subscription,
        target_node: Any = None,
        **kwargs: PubSubHandler,
    ) -> None:
        """
        Subscribes the client to the specified shard channels.
        Channels supplied as keyword arguments expect a channel name as the key
        and a callable as the value.
        ``Subscription`` objects can also be supplied positionally
        with an optional handler.
        A channel's callable will be invoked automatically when a message
        is received on that channel rather than producing a message
        via ``listen()`` or ``get_sharded_message()``.
        """
        new_s_channels = parse_pubsub_subscriptions(args, kwargs)
        ret_val = await self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
        # update the s_channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_s_channels = self._normalize_keys(new_s_channels)
        self.shard_channels.update(new_s_channels)
        self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
        return ret_val

    def sunsubscribe(self, *args, target_node=None) -> Awaitable:
        """
        Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
        all shard_channels
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            s_channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            parsed_args = []
            s_channels = self.shard_channels
        self.pending_unsubscribe_shard_channels.update(s_channels)
        return self.execute_command("SUNSUBSCRIBE", *parsed_args)

    async def listen(self) -> AsyncIterator:
        """Listen for messages on channels this client has been subscribed to"""
        while self.subscribed:
            response = await self.handle_message(await self.parse_response(block=True))
            if response is not None:
                yield response

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        response = await self.parse_response(block=(timeout is None), timeout=timeout)
        if response:
            return await self.handle_message(response, ignore_subscribe_messages)
        return None

    def ping(self, message=None) -> Awaitable[bool]:
        """
        Ping the Redis server to test connectivity.

        Sends a PING command (with ``message`` as its argument when given).
        Like every command sent through execute_command(), the reply is not
        read here; it arrives via get_message()/listen() as a message of type
        ``"pong"``, and the returned awaitable itself resolves to None.
        """
        args = ["PING", message] if message is not None else ["PING"]
        return self.execute_command(*args)

    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Returns:
            A dict with ``type``, ``pattern``, ``channel`` and ``data`` keys,
            or None when the message was consumed by a handler, filtered as a
            subscribe/unsubscribe confirmation, or ``response`` was None.
        """
        if response is None:
            return None
        # A non-list reply is a PING reply; normalize it to the list form,
        # mapping the bare "PONG" acknowledgement to empty data. Handle both
        # raw bytes and (with decode_responses) already-decoded str, so a str
        # reply is not indexed character by character below.
        if isinstance(response, bytes):
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        elif isinstance(response, str):
            response = ["pong", response] if response != "PONG" else ["pong", ""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        if message_type in ["message", "pmessage"]:
            channel = str_if_bytes(message["channel"])
            await record_pubsub_message(
                direction=PubSubDirection.RECEIVE,
                channel=channel,
            )
        elif message_type == "smessage":
            channel = str_if_bytes(message["channel"])
            await record_pubsub_message(
                direction=PubSubDirection.RECEIVE,
                channel=channel,
                sharded=True,
            )

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            elif message_type == "sunsubscribe":
                s_channel = response[1]
                if s_channel in self.pending_unsubscribe_shard_channels:
                    self.pending_unsubscribe_shard_channels.remove(s_channel)
                    self.shard_channels.pop(s_channel, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            elif message_type == "smessage":
                handler = self.shard_channels.get(message["channel"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
        pubsub=None,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

            >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

            >>> task.cancel()
            >>> await task

        Raises:
            PubSubError: if any subscribed channel, pattern or shard channel
                has no handler; messages for it would otherwise be read and
                silently discarded by this loop.
        """
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
        for s_channel, handler in self.shard_channels.items():
            if handler is None:
                raise PubSubError(
                    f"Shard Channel: '{s_channel}' has no handler registered"
                )

        await self.connect()
        while True:
            try:
                if pubsub is None:
                    await self.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
                else:
                    await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
            except asyncio.CancelledError:
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Structural type of a synchronous exception handler for ``PubSub.run()``."""

    def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle exception *e* raised while *pubsub* was processing messages."""
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Structural type of a coroutine exception handler for ``PubSub.run()``."""

    async def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle exception *e* raised while *pubsub* was processing messages."""
# Either kind of handler accepted by PubSub.run(exception_handler=...); run()
# awaits the handler's result when it is awaitable.
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]
# One queued pipeline command: (positional command args, options mapping).
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# Ordered list of queued commands, as held in Pipeline.command_stack.
CommandStackT = List[CommandT]
1620class Pipeline(Redis): # lgtm [py/init-calls-subclass]
1621 """
1622 Pipelines provide a way to transmit multiple commands to the Redis server
1623 in one transmission. This is convenient for batch processing, such as
1624 saving all the values in a list to Redis.
1626 All commands executed within a pipeline(when running in transactional mode,
1627 which is the default behavior) are wrapped with MULTI and EXEC
1628 calls. This guarantees all commands executed in the pipeline will be
1629 executed atomically.
1631 Any command raising an exception does *not* halt the execution of
1632 subsequent commands in the pipeline. Instead, the exception is caught
1633 and its instance is placed into the response list returned by execute().
1634 Code iterating over the response list should be able to deal with an
1635 instance of an exception as a potential value. In general, these will be
1636 ResponseError exceptions, such as those raised when issuing a command
1637 on a key of a different datatype.
1638 """
1640 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
1642 def __init__(
1643 self,
1644 connection_pool: ConnectionPool,
1645 response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
1646 transaction: bool,
1647 shard_hint: Optional[str],
1648 ):
1649 self.connection_pool = connection_pool
1650 self.connection = None
1651 self.response_callbacks = response_callbacks
1652 self.is_transaction = transaction
1653 self.shard_hint = shard_hint
1654 self.watching = False
1655 self.command_stack: CommandStackT = []
1656 self.scripts: Set[Script] = set()
1657 self.explicit_transaction = False
    async def __aenter__(self: _RedisT) -> _RedisT:
        """Enter ``async with``; the pipeline itself is the context value."""
        return self
    async def __aexit__(self, exc_type, exc_value, traceback):
        """Leave ``async with`` by calling reset(), whether or not it raised."""
        await self.reset()
    def __await__(self):
        # Makes the pipeline awaitable: ``await pipe`` evaluates to the
        # pipeline object itself, via the _async_self() coroutine.
        return self._async_self().__await__()
1668 _DEL_MESSAGE = "Unclosed Pipeline client"
    def __len__(self):
        """Return the number of commands currently queued."""
        return len(self.command_stack)
    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        # Without this, truthiness would fall back to __len__ and an empty
        # pipeline would be falsy.
        return True
    async def _async_self(self):
        """Coroutine that simply resolves to the pipeline; backs __await__."""
        return self
    async def reset(self):
        """
        Discard queued commands and scripts and release the connection.

        If WATCH is active on the held connection, UNWATCH is sent and its
        reply read; if that raises ConnectionError the connection is
        disconnected instead. The watching and explicit-transaction flags
        are then cleared and any held connection is returned to the pool.
        """
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None
    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()
1708 def multi(self):
1709 """
1710 Start a transactional block of the pipeline after WATCH commands
1711 are issued. End the transactional block with `execute`.
1712 """
1713 if self.explicit_transaction:
1714 raise RedisError("Cannot issue nested calls to MULTI")
1715 if self.command_stack:
1716 raise RedisError(
1717 "Commands without an initial WATCH have already been issued"
1718 )
1719 self.explicit_transaction = True
1721 def execute_command(
1722 self, *args, **kwargs
1723 ) -> Union["Pipeline", Awaitable["Pipeline"]]:
1724 if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
1725 return self.immediate_execute_command(*args, **kwargs)
1726 return self.pipeline_execute_command(*args, **kwargs)
    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ) -> None:
        """
        Close the connection reset watching state and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).

        Args:
            conn: connection the failed attempt ran on.
            error: exception raised by the failed attempt.
            failure_count: failures so far; the attempt's duration is only
                recorded while this does not exceed ``conn.retry.get_retries()``.
            start_time: ``time.monotonic()`` taken when the command started;
                must be provided whenever the duration is recorded.
            command_name: command name reported with the duration metric.

        Raises:
            WatchError: if WATCH was active (after calling reset()).
        """
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
                error=error,
                retry_attempts=failure_count,
            )
        await conn.disconnect(error=error, failure_count=failure_count)
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )
    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors for retry if we're already WATCHing a variable.
        Used when issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.

        A pool connection is acquired on first use and kept on
        ``self.connection``. Each failed attempt goes through
        _disconnect_reset_raise_on_watching, which raises WatchError while
        WATCH is active. Successful calls record their duration; errors that
        escape are counted and re-raised.

        Returns:
            The parsed reply from _send_command_parse_response.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._disconnect_reset_raise_on_watching(
                conn, error, failure_count, start_time, command_name
            )

        try:
            response = await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )

            return response
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
1826 def pipeline_execute_command(self, *args, **options):
1827 """
1828 Stage a command to be executed when execute() is next called
1830 Returns the current Pipeline object back so commands can be
1831 chained together, such as:
1833 pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
1835 At some other point, you can then run: pipe.execute(),
1836 which will execute all commands queued in the pipe.
1837 """
1838 self.command_stack.append((args, options))
1839 return self
    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        """
        Send ``commands`` wrapped in MULTI/EXEC in one write and parse the replies.

        Commands whose options contain EMPTY_RESPONSE are not sent; their
        EMPTY_RESPONSE value is spliced into the results instead. Per-command
        ResponseErrors (reported while queueing) are annotated and spliced
        into the EXEC result at their command's index.

        Returns:
            One entry per command: the callback-processed reply, or the
            exception/placeholder value for that position.

        Raises:
            WatchError: if EXEC returned None (a watched key changed).
            ResponseError: if the reply count does not match ``commands``
                (after disconnecting ``self.connection``), or the first
                queued error when EXEC aborts.
            ExecAbortError: if EXEC aborts with no earlier queued error.
            Any error in the results when ``raise_on_error`` is true
                (via raise_first_error).
        """
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove keys entry, it needs only for cache.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    # callbacks may be coroutines in the async client
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data
1919 async def _execute_pipeline(
1920 self, connection: Connection, commands: CommandStackT, raise_on_error: bool
1921 ):
1922 # build up all commands into a single request to increase network perf
1923 all_cmds = connection.pack_commands([args for args, _ in commands])
1924 await connection.send_packed_command(all_cmds)
1926 response = []
1927 for args, options in commands:
1928 try:
1929 response.append(
1930 await self.parse_response(connection, args[0], **options)
1931 )
1932 except ResponseError as e:
1933 response.append(e)
1935 if raise_on_error:
1936 self.raise_first_error(commands, response)
1937 return response
def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
    """
    Raise the first ``ResponseError`` found in ``response``, if any.

    Before raising, the error's message is annotated with its 1-based
    position and the matching command from ``commands``. Returns ``None``
    when ``response`` contains no ``ResponseError``.
    """
    first_failure = next(
        (
            (position, item)
            for position, item in enumerate(response)
            if isinstance(item, ResponseError)
        ),
        None,
    )
    if first_failure is None:
        return
    position, failure = first_failure
    self.annotate_exception(failure, position + 1, commands[position][0])
    raise failure
def annotate_exception(
    self, exception: Exception, number: int, command: Iterable[object]
) -> None:
    """
    Prefix ``exception``'s message with the pipeline command that caused it.

    ``exception.args[0]`` is replaced with a message naming the command's
    position and its (truncated) text, followed by the original message. Any
    further ``args`` are kept as they are.

    :param exception: the error to annotate in place.
    :param number: 1-based position of the command in the pipeline.
    :param command: the command's arguments.
    """
    cmd = " ".join(map(safe_str, command))
    # Interpolate the original message (args[0]), not the whole args tuple;
    # otherwise the text reads "... caused error: ('ERR ...',)".
    original = exception.args[0] if exception.args else ""
    msg = (
        f"Command # {number} ({truncate_text(cmd)}) "
        f"of pipeline caused error: {original}"
    )
    exception.args = (msg,) + exception.args[1:]
async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """
    Parse one reply through the base implementation and track WATCH state.

    A command listed in ``UNWATCH_COMMANDS`` clears ``self.watching``, and a
    ``WATCH`` command sets it. The parsed reply is returned unchanged.
    """
    parsed = await super().parse_response(connection, command_name, **options)
    clears_watch = command_name in self.UNWATCH_COMMANDS
    if clears_watch:
        self.watching = False
    elif command_name == "WATCH":
        self.watching = True
    return parsed
async def load_scripts(self):
    """
    Make sure every script registered on this pipeline exists on the server.

    Sends ``SCRIPT EXISTS`` for all script SHAs. Each missing script is then
    loaded with ``SCRIPT LOAD``, and its ``sha`` is updated from the reply.
    """
    pending = list(self.scripts)
    run_now = self.immediate_execute_command
    # The regular script_* helpers would only be buffered in the pipeline,
    # so the raw commands are executed immediately instead.
    present = await run_now("SCRIPT EXISTS", *[script.sha for script in pending])
    if all(present):
        return
    for script, is_loaded in zip(pending, present):
        if not is_loaded:
            script.sha = await run_now("SCRIPT LOAD", script.script)
1978 async def _disconnect_raise_on_watching(
1979 self,
1980 conn: Connection,
1981 error: Exception,
1982 failure_count: Optional[int] = None,
1983 start_time: Optional[float] = None,
1984 command_name: Optional[str] = None,
1985 ):
1986 """
1987 Close the connection, raise an exception if we were watching.
1989 The supported exceptions are already checked in the
1990 retry object so we don't need to do it here.
1992 After we disconnect the connection, it will try to reconnect and
1993 do a health check as part of the send_command logic(on connection level).
1994 """
1995 if (
1996 error
1997 and failure_count is not None
1998 and failure_count <= conn.retry.get_retries()
1999 ):
2000 await record_operation_duration(
2001 command_name=command_name,
2002 duration_seconds=time.monotonic() - start_time,
2003 server_address=getattr(conn, "host", None),
2004 server_port=getattr(conn, "port", None),
2005 db_namespace=str(conn.db),
2006 error=error,
2007 retry_attempts=failure_count,
2008 )
2009 await conn.disconnect(error=error, failure_count=failure_count)
2010 # if we were watching a variable, the watch is no longer valid
2011 # since this connection has died. raise a WatchError, which
2012 # indicates the user should retry this transaction.
2013 if self.watching:
2014 raise WatchError(
2015 f"A {type(error).__name__} occurred while watching one or more keys"
2016 )
async def execute(self, raise_on_error: bool = True) -> List[Any]:
    """
    Execute all the commands in the current pipeline.

    Returns an empty list immediately when nothing is queued and no keys are
    being watched. Otherwise any registered scripts are loaded first. The
    queued commands are then sent either as a MULTI/EXEC transaction or as a
    plain pipeline, with retries driven by ``conn.retry``. Duration and error
    metrics are recorded, and the pipeline is reset afterwards whether or not
    the call succeeded.

    :param raise_on_error: forwarded to the transaction/pipeline runner; when
        true, the first command error is raised instead of being returned.
    :returns: the list of replies produced by the runner.
    """
    queued = self.command_stack
    if not (queued or self.watching):
        return []

    if self.scripts:
        await self.load_scripts()

    if self.is_transaction or self.explicit_transaction:
        runner, operation_name = self._execute_transaction, "MULTI"
    else:
        runner, operation_name = self._execute_pipeline, "PIPELINE"

    conn = self.connection
    if not conn:
        conn = await self.connection_pool.get_connection()
        # Keep it on the pipeline so reset() hands it back to the pool
        # once we are done.
        self.connection = conn
    conn = cast(Connection, conn)

    started_at = time.monotonic()  # start of the observed operation
    attempts_seen = 0  # last failure_count reported by the retry loop

    def on_failure(error, failure_count):
        nonlocal attempts_seen
        attempts_seen = failure_count
        return self._disconnect_raise_on_watching(
            conn, error, failure_count, started_at, operation_name
        )

    try:
        replies = await conn.retry.call_with_retry(
            lambda: runner(conn, queued, raise_on_error),
            on_failure,
            with_failure_count=True,
        )
        await record_operation_duration(
            command_name=operation_name,
            duration_seconds=time.monotonic() - started_at,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
        return replies
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=attempts_seen,
            is_internal=False,
        )
        raise
    finally:
        await self.reset()
async def discard(self):
    """
    Flush all previously queued commands by issuing DISCARD.

    The reply from ``execute_command`` is not returned.
    See: https://redis.io/commands/DISCARD
    """
    discard_command = "DISCARD"
    await self.execute_command(discard_command)
    return None
async def watch(self, *names: KeyT):
    """
    Watch the keys given in ``names``.

    :returns: the result of executing ``WATCH`` with the given keys.
    :raises RedisError: if called after MULTI (an explicit transaction).
    """
    if not self.explicit_transaction:
        return await self.execute_command("WATCH", *names)
    raise RedisError("Cannot issue a WATCH after a MULTI")
async def unwatch(self):
    """
    Stop watching all previously watched keys.

    UNWATCH is only sent when the pipeline is currently watching. The reply
    is returned if it is truthy; otherwise ``True`` is returned.
    """
    if not self.watching:
        return True
    reply = await self.execute_command("UNWATCH")
    return reply or True