Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 22%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import copy
3import inspect
4import re
5import time
6import warnings
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 AsyncIterator,
11 Awaitable,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Literal,
17 Mapping,
18 MutableMapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27 cast,
28)
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34 bool_ok,
35)
36from redis.asyncio.connection import (
37 Connection,
38 ConnectionPool,
39 SSLConnection,
40 UnixDomainSocketConnection,
41)
42from redis.asyncio.lock import Lock
43from redis.asyncio.observability.recorder import (
44 record_error_count,
45 record_operation_duration,
46 record_pubsub_message,
47)
48from redis.asyncio.retry import Retry
49from redis.backoff import ExponentialWithJitterBackoff
50from redis.client import (
51 EMPTY_RESPONSE,
52 NEVER_DECODE,
53 AbstractRedis,
54 CaseInsensitiveDict,
55)
56from redis.commands import (
57 AsyncCoreCommands,
58 AsyncRedisModuleCommands,
59 AsyncSentinelCommands,
60 list_or_args,
61)
62from redis.commands.helpers import partition_pubsub_subscriptions_by_handler
63from redis.credentials import CredentialProvider
64from redis.driver_info import DriverInfo, resolve_driver_info
65from redis.event import (
66 AfterPooledConnectionsInstantiationEvent,
67 AfterPubSubConnectionInstantiationEvent,
68 AfterSingleConnectionInstantiationEvent,
69 ClientType,
70 EventDispatcher,
71)
72from redis.exceptions import (
73 ConnectionError,
74 ExecAbortError,
75 PubSubError,
76 RedisError,
77 ResponseError,
78 WatchError,
79)
80from redis.observability.attributes import PubSubDirection
81from redis.typing import ChannelT, EncodableT, KeyT
82from redis.utils import (
83 DEFAULT_RESP_VERSION,
84 SSL_AVAILABLE,
85 _set_info_logger,
86 check_protocol_version,
87 deprecated_args,
88 deprecated_function,
89 safe_str,
90 str_if_bytes,
91 truncate_text,
92)
# The ssl types are only imported for static type checkers (and only when SSL
# support is available). In every other case the names are bound to None so
# that annotations referring to them can still be evaluated.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None
# Callable taking a Dict[str, str] and returning an awaitable
# (presumably the signature of a pub/sub message handler - confirm at use site).
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Self-type used by Redis methods that return the instance they were called on.
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    # Needed for annotations only; not imported at runtime.
    from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications
    from redis.commands.core import Script
class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback.

    A conforming callable takes the response as its first argument plus
    arbitrary keyword options.
    """

    def __call__(self, response: Any, **kwargs): ...
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of an asynchronous (awaitable) response callback.

    A conforming callable takes the response as its first argument plus
    arbitrary keyword options, and must be awaited to obtain its result.
    """

    async def __call__(self, response: Any, **kwargs): ...
# A response callback may be either a plain callable or a coroutine function.
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    # Maps a command name to the callback applied to that command's raw response.
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]
@classmethod
def from_url(
    cls: Type["Redis"],
    url: str,
    single_connection_client: bool = False,
    auto_close_connection_pool: Optional[bool] = None,
    **kwargs,
) -> "Redis":
    """
    Build a Redis client from a connection URL.

    Accepted URL forms::

        redis://[[username]:[password]]@localhost:6379/0
        rediss://[[username]:[password]]@localhost:6379/0
        unix://[username@]/path/to/socket.sock?db=0[&password=password]

    Schemes:

    - ``redis://`` creates a TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/redis>
    - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
      <https://www.iana.org/assignments/uri-schemes/prov/rediss>
    - ``unix://`` creates a Unix Domain Socket connection.

    The username, password, hostname, path and every querystring value are
    run through ``urllib.parse.unquote`` so percent-encoded characters are
    replaced by the characters they stand for.

    The database number is taken from the first of these that is present:

    1. a ``db`` querystring option, e.g. ``redis://localhost?db=0``
    2. for ``redis://`` / ``rediss://``, the URL path, e.g.
       ``redis://localhost/0``
    3. a ``db`` keyword argument passed to this function

    When none is given, db 0 is used.

    Querystring options are cast to their Python types; booleans may be
    spelled "True"/"False" or "Yes"/"No". A value that cannot be cast
    raises ``ValueError``. Querystring and keyword arguments are forwarded
    to the ``ConnectionPool`` initializer, and on conflict the querystring
    value wins.

    ``auto_close_connection_pool`` is deprecated; passing any non-None
    value emits a ``DeprecationWarning``. When omitted, the returned client
    closes the pool it created.
    """
    pool = ConnectionPool.from_url(url, **kwargs)
    instance = cls(
        connection_pool=pool,
        single_connection_client=single_connection_client,
    )
    if auto_close_connection_pool is None:
        # The pool was created here, so by default the client owns it.
        auto_close_connection_pool = True
    else:
        warnings.warn(
            DeprecationWarning(
                '"auto_close_connection_pool" is deprecated '
                "since version 5.0.1. "
                "Please create a ConnectionPool explicitly and "
                "provide to the Redis() constructor instead."
            )
        )
    instance.auto_close_connection_pool = auto_close_connection_pool
    return instance
@classmethod
def from_pool(
    cls: Type["Redis"],
    connection_pool: ConnectionPool,
) -> "Redis":
    """
    Wrap an existing connection pool in a new Redis client.

    The returned client takes ownership of the pool: its
    ``auto_close_connection_pool`` flag is set, so the pool is closed
    together with the client.
    """
    owner = cls(connection_pool=connection_pool)
    # The constructor disables auto-close for caller-supplied pools;
    # re-enable it because ownership is handed over here.
    owner.auto_close_connection_pool = True
    return owner
@deprecated_args(
    args_to_warn=["retry_on_timeout"],
    reason="TimeoutError is included by default.",
    version="6.0.0",
)
@deprecated_args(
    args_to_warn=["lib_name", "lib_version"],
    reason="Use 'driver_info' parameter instead. "
    "lib_name and lib_version will be removed in a future version.",
)
def __init__(
    self,
    *,
    host: str = "localhost",
    port: int = 6379,
    db: Union[str, int] = 0,
    password: Optional[str] = None,
    socket_timeout: Optional[float] = None,
    socket_connect_timeout: Optional[float] = None,
    socket_keepalive: Optional[bool] = None,
    socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
    connection_pool: Optional[ConnectionPool] = None,
    unix_socket_path: Optional[str] = None,
    encoding: str = "utf-8",
    encoding_errors: str = "strict",
    decode_responses: bool = False,
    retry_on_timeout: bool = False,
    retry: Retry = Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
    ),
    retry_on_error: Optional[list] = None,
    ssl: bool = False,
    ssl_keyfile: Optional[str] = None,
    ssl_certfile: Optional[str] = None,
    ssl_cert_reqs: Union[str, VerifyMode] = "required",
    ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
    ssl_ca_certs: Optional[str] = None,
    ssl_ca_data: Optional[str] = None,
    ssl_ca_path: Optional[str] = None,
    ssl_check_hostname: bool = True,
    ssl_min_version: Optional[TLSVersion] = None,
    ssl_ciphers: Optional[str] = None,
    ssl_password: Optional[str] = None,
    max_connections: Optional[int] = None,
    single_connection_client: bool = False,
    health_check_interval: int = 0,
    client_name: Optional[str] = None,
    lib_name: Optional[str] = None,
    lib_version: Optional[str] = None,
    driver_info: Optional["DriverInfo"] = None,
    username: Optional[str] = None,
    auto_close_connection_pool: Optional[bool] = None,
    redis_connect_func=None,
    credential_provider: Optional[CredentialProvider] = None,
    protocol: Optional[int] = 3,
    event_dispatcher: Optional[EventDispatcher] = None,
):
    """
    Initialize a new Redis client.

    Retry policy for specific errors can be configured in two ways:

    1. Pass the error/s in ``retry_on_error`` (optionally together with a
       custom ``Retry`` object in ``retry``). Retries then happen on the
       Retry object's default errors plus the ones listed in
       ``retry_on_error``.

    2. Pass a ``Retry`` object with its own 'supported_errors' in
       ``retry``. This completely replaces the set of errors that
       trigger a retry.

    ``retry_on_timeout`` is deprecated - include TimeoutError either in
    the Retry object or in ``retry_on_error`` instead.

    When ``connection_pool`` is given, the pool's own retry configuration
    is used, none of the connection arguments are consulted, and the
    client will not close the pool automatically.
    """
    self._event_dispatcher = (
        EventDispatcher() if event_dispatcher is None else event_dispatcher
    )

    # auto_close_connection_pool only matters when the pool is created
    # here; a caller-supplied pool is assumed to be managed by the caller.
    if auto_close_connection_pool is None:
        auto_close_connection_pool = True
    else:
        warnings.warn(
            DeprecationWarning(
                '"auto_close_connection_pool" is deprecated '
                "since version 5.0.1. "
                "Please create a ConnectionPool explicitly and "
                "provide to the Redis() constructor instead."
            )
        )

    if connection_pool:
        # A pool was passed in: never close it on the caller's behalf.
        self.auto_close_connection_pool = False
    else:
        # Build an internal pool that this instance is expected to close.
        # driver_info wins when given; otherwise it is derived from
        # lib_name / lib_version.
        resolved_driver_info = resolve_driver_info(
            driver_info, lib_name, lib_version
        )
        pool_kwargs: Dict[str, Any] = {
            "db": db,
            "username": username,
            "password": password,
            "credential_provider": credential_provider,
            "socket_timeout": socket_timeout,
            "encoding": encoding,
            "encoding_errors": encoding_errors,
            "decode_responses": decode_responses,
            "retry_on_error": retry_on_error or [],
            # The default Retry instance is shared by every call, so each
            # pool gets its own copy.
            "retry": copy.deepcopy(retry),
            "max_connections": max_connections,
            "health_check_interval": health_check_interval,
            "client_name": client_name,
            "driver_info": resolved_driver_info,
            "redis_connect_func": redis_connect_func,
            "protocol": protocol,
        }
        if unix_socket_path is None:
            # TCP specific options
            pool_kwargs.update(
                host=host,
                port=port,
                socket_connect_timeout=socket_connect_timeout,
                socket_keepalive=socket_keepalive,
                socket_keepalive_options=socket_keepalive_options,
            )
            if ssl:
                pool_kwargs.update(
                    connection_class=SSLConnection,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    ssl_cert_reqs=ssl_cert_reqs,
                    ssl_include_verify_flags=ssl_include_verify_flags,
                    ssl_exclude_verify_flags=ssl_exclude_verify_flags,
                    ssl_ca_certs=ssl_ca_certs,
                    ssl_ca_data=ssl_ca_data,
                    ssl_ca_path=ssl_ca_path,
                    ssl_check_hostname=ssl_check_hostname,
                    ssl_min_version=ssl_min_version,
                    ssl_ciphers=ssl_ciphers,
                    ssl_password=ssl_password,
                )
        else:
            pool_kwargs.update(
                path=unix_socket_path,
                connection_class=UnixDomainSocketConnection,
            )
        # This flag is only honoured when no pool was passed in.
        self.auto_close_connection_pool = auto_close_connection_pool
        connection_pool = ConnectionPool(**pool_kwargs)

    # Announce the pool regardless of who created it.
    self._event_dispatcher.dispatch(
        AfterPooledConnectionsInstantiationEvent(
            [connection_pool], ClientType.ASYNC, credential_provider
        )
    )

    self.connection_pool = connection_pool
    self.single_connection_client = single_connection_client
    self.connection: Optional[Connection] = None

    # Start from the shared callbacks, then layer on the protocol-specific set.
    self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)
    pool_protocol = self.connection_pool.connection_kwargs.get(
        "protocol", DEFAULT_RESP_VERSION
    )
    if check_protocol_version(pool_protocol, 3):
        protocol_callbacks = _RedisCallbacksRESP3
    else:
        protocol_callbacks = _RedisCallbacksRESP2
    self.response_callbacks.update(protocol_callbacks)

    # A single-connection client must serialize both the creation and the
    # use of its connection, otherwise e.g. asyncio.gather over several
    # commands would race on it.
    self._single_conn_lock = asyncio.Lock()

    # Async-context-manager usage is counted so the pool is only closed
    # once nobody is using the client any more.
    self._usage_counter = 0
    self._usage_lock = asyncio.Lock()
def __repr__(self):
    """Return ``<module.ClassName(<repr of the connection pool>)>``."""
    klass = self.__class__
    qualified = f"{klass.__module__}.{klass.__name__}"
    return f"<{qualified}({self.connection_pool!r})>"
def __await__(self):
    """Make ``await client`` equivalent to ``await client.initialize()``."""
    bootstrap = self.initialize()
    return bootstrap.__await__()
async def initialize(self: _RedisT) -> _RedisT:
    """
    Prepare the client for use and return it.

    For a regular (pooled) client this does nothing. For a
    single-connection client it checks out the dedicated connection from
    the pool if that has not happened yet (guarded by the single-connection
    lock) and then dispatches ``AfterSingleConnectionInstantiationEvent``.
    """
    if not self.single_connection_client:
        return self

    async with self._single_conn_lock:
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    self._event_dispatcher.dispatch(
        AfterSingleConnectionInstantiationEvent(
            self.connection, ClientType.ASYNC, self._single_conn_lock
        )
    )
    return self
def set_response_callback(self, command: str, callback: ResponseCallbackT):
    """Register ``callback`` as the response callback for ``command``."""
    callbacks = self.response_callbacks
    callbacks[command] = callback
def get_encoder(self):
    """Return the encoder provided by the connection pool."""
    pool = self.connection_pool
    return pool.get_encoder()
def get_connection_kwargs(self):
    """Return the keyword arguments the pool uses to create connections."""
    pool = self.connection_pool
    return pool.connection_kwargs
def get_retry(self) -> Optional[Retry]:
    """Return the ``retry`` entry of the connection kwargs, or None if unset."""
    connection_kwargs = self.get_connection_kwargs()
    return connection_kwargs.get("retry")
def set_retry(self, retry: Retry) -> None:
    """
    Install a new retry policy.

    The policy is stored in the connection kwargs and also handed to the
    pool through ``connection_pool.set_retry``.
    """
    connection_kwargs = self.get_connection_kwargs()
    connection_kwargs.update(retry=retry)
    self.connection_pool.set_retry(retry)
def load_external_module(self, funcname, func):
    """
    Attach an externally defined redis module (and its namespace) to this
    client instance.

    funcname - name of the attribute to create on the client
    func - the object bound to that attribute

    Example: given a custom redis module ``foomod`` that adds the commands
    'foo.dothing' and 'foo.anotherthing', expose it on the client with:

    from redis import Redis
    from foomodule import F
    r = Redis()
    r.load_external_module("foo", F)
    r.foo().dothing('your', 'arguments')

    A concrete example is the reimport of the redisjson module in
    tests/test_connection.py::test_loading_external_modules
    """
    # Bound on the instance only; the class and other clients are untouched.
    setattr(self, funcname, func)
def pipeline(
    self, transaction: bool = True, shard_hint: Optional[str] = None
) -> "Pipeline":
    """
    Create a pipeline that queues commands for later execution.

    ``transaction`` selects whether the queued commands are executed
    atomically. Besides atomicity, pipelines cut down the number of
    round trips between client and server.
    """
    pool = self.connection_pool
    callbacks = self.response_callbacks
    return Pipeline(pool, callbacks, transaction, shard_hint)
async def transaction(
    self,
    func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
    *watches: KeyT,
    shard_hint: Optional[str] = None,
    value_from_callable: bool = False,
    watch_delay: Optional[float] = None,
):
    """
    Run the callable ``func`` as a transaction while watching every key in
    ``watches``.

    ``func`` receives a single Pipeline argument and may be a plain
    function or return an awaitable. The whole attempt is repeated whenever
    it ends in a ``WatchError``, sleeping ``watch_delay`` seconds between
    attempts when that value is positive.

    Returns the value produced by ``func`` when ``value_from_callable`` is
    true, otherwise the result of executing the pipeline.
    """
    pipe: "Pipeline"
    async with self.pipeline(True, shard_hint) as pipe:
        while True:
            try:
                if watches:
                    await pipe.watch(*watches)
                outcome = func(pipe)
                if inspect.isawaitable(outcome):
                    outcome = await outcome
                executed = await pipe.execute()
            except WatchError:
                # Lost the optimistic-lock race; optionally back off, then retry.
                if watch_delay is not None and watch_delay > 0:
                    await asyncio.sleep(watch_delay)
            else:
                return outcome if value_from_callable else executed
def lock(
    self,
    name: KeyT,
    timeout: Optional[float] = None,
    sleep: float = 0.1,
    blocking: bool = True,
    blocking_timeout: Optional[float] = None,
    lock_class: Optional[Type[Lock]] = None,
    thread_local: bool = True,
    raise_on_release_error: bool = True,
) -> Lock:
    """
    Create a Lock on key ``name`` that mimics the behavior of
    threading.Lock.

    ``timeout`` is the maximum life of the lock. Left unset, the lock
    stays held until release() is called.

    ``sleep`` is how long to sleep per loop iteration while the lock is in
    blocking mode and another client holds it.

    ``blocking`` chooses whether ``acquire`` waits for the lock or fails
    immediately (returning False without acquiring). Defaults to True and
    can be overridden per call through the ``blocking`` argument of
    ``acquire``.

    ``blocking_timeout`` caps, in seconds (float or integer), the time
    spent trying to acquire the lock. ``None`` means keep trying forever.

    ``lock_class`` forces a specific lock implementation. As of redis-py
    3.0 the only implementation shipped is ``Lock`` (a Lua-based lock), so
    this is rarely needed unless you wrote your own lock class.

    ``thread_local`` chooses whether the lock token lives in thread-local
    storage. It does by default, so a thread only ever sees its own token
    and not one set by another thread. Consider this timeline:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                 thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
                 Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
                 key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
                 thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
                 token is *not* stored in thread local storage, then
                 thread-1 would see the token value as "xyz" and would be
                 able to successfully release the thread-2's lock.

    Some use cases need thread-local storage disabled, for example when
    one thread acquires the lock and hands the instance to a worker thread
    that releases it later: with thread-local storage the worker would not
    see the token set by the acquiring thread. Such cases are assumed to
    be uncommon, hence the default.

    ``raise_on_release_error`` chooses whether leaving the context manager
    raises when the lock is no longer owned. True (the default) raises;
    False logs a warning and suppresses the exception.
    """
    factory = Lock if lock_class is None else lock_class
    return factory(
        self,
        name,
        timeout=timeout,
        sleep=sleep,
        blocking=blocking,
        blocking_timeout=blocking_timeout,
        thread_local=thread_local,
        raise_on_release_error=raise_on_release_error,
    )
def pubsub(self, **kwargs) -> "PubSub":
    """
    Create a Publish/Subscribe object bound to this client's pool.

    Use it to subscribe to channels and listen for the messages published
    to them. Extra keyword arguments are forwarded to ``PubSub``.
    """
    pool = self.connection_pool
    dispatcher = self._event_dispatcher
    return PubSub(pool, event_dispatcher=dispatcher, **kwargs)
def keyspace_notifications(
    self,
    key_prefix: Union[str, bytes, None] = None,
    ignore_subscribe_messages: bool = True,
) -> "AsyncKeyspaceNotifications":
    """
    Create an
    :class:`~redis.asyncio.keyspace_notifications.AsyncKeyspaceNotifications`
    object for subscribing to keyspace and keyevent notifications.

    Note: the Redis server must have keyspace notifications enabled via
    its ``notify-keyspace-events`` configuration option.

    Args:
        key_prefix: Optional prefix to filter and strip from keys in
            notifications.
        ignore_subscribe_messages: If True, subscribe/unsubscribe
            confirmations are not returned by get_message/listen.
    """
    # Imported here rather than at module level (only needed by this method).
    from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications

    settings = {
        "key_prefix": key_prefix,
        "ignore_subscribe_messages": ignore_subscribe_messages,
    }
    return AsyncKeyspaceNotifications(self, **settings)
def monitor(self) -> "Monitor":
    """Create a Monitor that uses this client's connection pool."""
    pool = self.connection_pool
    return Monitor(pool)
def client(self) -> "Redis":
    """
    Create a new single-connection client of the same class that shares
    this client's connection pool.
    """
    factory = self.__class__
    return factory(
        connection_pool=self.connection_pool,
        single_connection_client=True,
    )
async def __aenter__(self: _RedisT) -> _RedisT:
    """
    Enter the async context.

    The usage counter is bumped first so that the connection pool is only
    closed (via aclose()) once no context is using the client; then the
    client is initialized and returned.
    """
    await self._increment_usage()
    try:
        # Initialize the client (i.e. establish connection, etc.)
        ready = await self.initialize()
    except Exception:
        # Undo the increment so the counter stays in sync with real usage.
        await self._decrement_usage()
        raise
    return ready
async def _increment_usage(self) -> int:
    """
    Add one to the usage counter while holding the usage lock.

    Returns the counter's value after the increment.
    """
    async with self._usage_lock:
        updated = self._usage_counter + 1
        self._usage_counter = updated
        return updated
async def _decrement_usage(self) -> int:
    """
    Subtract one from the usage counter while holding the usage lock.

    Returns the counter's value after the decrement.
    """
    async with self._usage_lock:
        updated = self._usage_counter - 1
        self._usage_counter = updated
        return updated
async def __aexit__(self, exc_type, exc_value, traceback):
    """
    Leave the async context.

    The usage counter is decremented; when it reaches zero this was the
    last active context and the client is closed via aclose(). Both steps
    are wrapped in asyncio.shield.
    """
    remaining = await asyncio.shield(self._decrement_usage())
    if remaining != 0:
        return
    # This was the last active context, so disconnect the pool.
    await asyncio.shield(self.aclose())
_DEL_MESSAGE = "Unclosed Redis client"

# warnings.warn and asyncio.get_running_loop are captured as argument
# defaults because the module globals may already be gone by the time
# __del__ runs at interpreter shutdown.
def __del__(
    self,
    _warn: Any = warnings.warn,
    _grl: Any = asyncio.get_running_loop,
) -> None:
    """
    Warn about, report and close a connection that was never released.

    Nothing happens unless the instance still holds a connection. In that
    case a ResourceWarning is issued, the running loop's exception handler
    is notified (skipped when no loop is running) and the connection is
    closed synchronously through its ``_close()`` method.
    """
    leaked = getattr(self, "connection", None) if hasattr(self, "connection") else None
    if leaked is None:
        return
    _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
    try:
        report = {"client": self, "message": self._DEL_MESSAGE}
        _grl().call_exception_handler(report)
    except RuntimeError:
        # No running event loop to report to.
        pass
    self.connection._close()
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
    """
    Close the Redis client connection.

    A dedicated connection held by the client is handed back to the pool.

    Args:
        close_connection_pool:
            decides whether the connection pool used by this client is
            closed as well, overriding Redis.auto_close_connection_pool.
            With the default (None), Redis.auto_close_connection_pool
            decides.
    """
    held = self.connection
    if held:
        # Clear the reference before awaiting the release.
        self.connection = None
        await self.connection_pool.release(held)

    if close_connection_pool is None:
        should_close_pool = self.auto_close_connection_pool
    else:
        should_close_pool = close_connection_pool
    if should_close_pool:
        await self.connection_pool.disconnect()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
    """
    Deprecated alias of aclose(), kept for backwards compatibility.
    """
    result = await self.aclose(close_connection_pool)
    return result
async def _send_command_parse_response(self, conn, command_name, *args, **options):
    """
    Send the command in ``args`` over ``conn`` and return its parsed
    response.
    """
    await conn.send_command(*args)
    parsed = await self.parse_response(conn, command_name, **options)
    return parsed
async def _close_connection(
    self,
    conn: Connection,
    error: Optional[BaseException] = None,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
):
    """
    Close the connection before retrying.

    The supported exceptions are already checked in the
    retry object so we don't need to do it here.

    After we disconnect the connection, it will try to reconnect and
    do a health check as part of the send_command logic(on connection level).

    Args:
        conn: connection to disconnect.
        error: the exception that triggered the retry, if any.
        failure_count: number of failures so far, if known.
        start_time: ``time.monotonic()`` value taken when the operation
            started; used only to compute the reported duration.
        command_name: name of the command, used for the duration metric.

    The operation duration is recorded only when an error is given, the
    failure count is within the connection's retry budget and a start time
    is available. Without a start time no duration can be computed, so the
    metric is skipped instead of failing and the disconnect still happens.
    """
    should_record = (
        error
        and failure_count is not None
        and failure_count <= conn.retry.get_retries()
        # start_time defaults to None; subtracting it would raise TypeError
        # and prevent the disconnect below from running.
        and start_time is not None
    )
    if should_record:
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )

    await conn.disconnect(error=error, failure_count=failure_count)
798 # COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
    """
    Execute a command and return a parsed response.

    ``args[0]`` is the command name. The command is sent on the client's
    dedicated connection when there is one, otherwise on a connection
    borrowed from the pool, which is released again afterwards. Sending
    and parsing run under the connection's retry policy; before each retry
    the connection is closed via ``_close_connection``. The operation
    duration is recorded on success and the error count on failure, after
    which the exception is re-raised.
    """
    await self.initialize()
    pool = self.connection_pool
    command_name = args[0]
    conn = self.connection or await pool.get_connection()

    # Start timing for observability
    started_at = time.monotonic()
    # Latest failure count reported by the retry machinery, for error reporting
    retries_seen = 0

    def on_failure(error, failure_count):
        nonlocal retries_seen
        retries_seen = failure_count
        return self._close_connection(
            conn, error, failure_count, started_at, command_name
        )

    def attempt():
        return self._send_command_parse_response(
            conn, command_name, *args, **options
        )

    if self.single_connection_client:
        await self._single_conn_lock.acquire()
    try:
        result = await conn.retry.call_with_retry(
            attempt,
            on_failure,
            with_failure_count=True,
        )

        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - started_at,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
        return result
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=retries_seen,
            is_internal=False,
        )
        raise
    finally:
        if self.single_connection_client:
            self._single_conn_lock.release()
        if not self.connection:
            # The connection was borrowed from the pool; give it back.
            await pool.release(conn)
async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """
    Read one response from ``connection`` and post-process it.

    With the ``NEVER_DECODE`` option the response is read with decoding
    disabled. If reading raises ``ResponseError`` and an ``EMPTY_RESPONSE``
    option was given, that option's value is returned instead of raising.
    Otherwise the response goes through the callback registered for
    ``command_name`` (awaited when it returns an awaitable), or is returned
    unchanged when no callback is registered.
    """
    try:
        if NEVER_DECODE in options:
            response = await connection.read_response(disable_decoding=True)
            del options[NEVER_DECODE]
        else:
            response = await connection.read_response()
    except ResponseError:
        if EMPTY_RESPONSE not in options:
            raise
        return options[EMPTY_RESPONSE]

    # These options must not be forwarded to the callback.
    options.pop(EMPTY_RESPONSE, None)
    # Remove keys entry, it needs only for cache.
    options.pop("keys", None)

    if command_name not in self.response_callbacks:
        return response

    # Mypy bug: https://github.com/python/mypy/issues/10977
    command_name = cast(str, command_name)
    callback = self.response_callbacks[command_name]
    parsed = callback(response, **options)
    if inspect.isawaitable(parsed):
        return await parsed
    return parsed
# Alias of Redis (presumably kept for backwards compatibility - confirm).
StrictRedis = Redis
class MonitorCommandInfo(TypedDict):
    """Shape of the dict returned by ``Monitor.next_command()``."""

    time: float  # leading token of the MONITOR line, converted with float()
    db: int  # database id from the bracketed section, converted with int()
    client_address: str  # "lua", "unix", or the part of the client info before the last ":"
    client_port: str  # part after the last ":", the text after "unix:" , or "" for lua
    client_type: str  # one of "lua", "unix", "tcp"
    command: str  # quoted command arguments joined with spaces, quotes unescaped
895class Monitor:
896 """
897 Monitor is useful for handling the MONITOR command to the redis server.
898 next_command() method returns one command from monitor
899 listen() method yields commands from monitor.
900 """
902 monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
903 command_re = re.compile(r'"(.*?)(?<!\\)"')
905 def __init__(self, connection_pool: ConnectionPool):
906 self.connection_pool = connection_pool
907 self.connection: Optional[Connection] = None
909 async def connect(self):
910 if self.connection is None:
911 self.connection = await self.connection_pool.get_connection()
913 async def __aenter__(self):
914 await self.connect()
915 await self.connection.send_command("MONITOR")
916 # check that monitor returns 'OK', but don't return it to user
917 response = await self.connection.read_response()
918 if not bool_ok(response):
919 raise RedisError(f"MONITOR failed: {response}")
920 return self
922 async def __aexit__(self, *args):
923 await self.connection.disconnect()
924 await self.connection_pool.release(self.connection)
async def next_command(self) -> MonitorCommandInfo:
    """
    Read one line of MONITOR output and break it into its components.

    Returns:
        A dict with the keys ``time``, ``db``, ``client_address``,
        ``client_port``, ``client_type`` and ``command``.
    """
    await self.connect()
    raw = await self.connection.read_response()
    if isinstance(raw, bytes):
        raw = self.connection.encoder.decode(raw, force=True)

    # Line layout: "<timestamp> [<db> <client>] <quoted arguments>"
    timestamp, payload = raw.split(" ", 1)
    db_id, client_info, quoted_args = self.monitor_re.match(payload).groups()

    # Redis escapes double quotes because every argument is itself wrapped
    # in double quotes. We have no such constraint, so the escaping is
    # removed and the bare quote is kept.
    command = " ".join(self.command_re.findall(quoted_args)).replace('\\"', '"')

    if client_info == "lua":
        origin = ("lua", "", "lua")
    elif client_info.startswith("unix"):
        origin = ("unix", client_info[5:], "unix")
    else:
        # rsplit, because IPv6 addresses contain colons themselves
        host, port = client_info.rsplit(":", 1)
        origin = (host, port, "tcp")
    client_address, client_port, client_type = origin

    return {
        "time": float(timestamp),
        "db": int(db_id),
        "client_address": client_address,
        "client_port": client_port,
        "client_type": client_type,
        "command": command,
    }
async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
    """Yield parsed monitor entries one at a time, without end."""
    while True:
        entry = await self.next_command()
        yield entry
968class PubSub:
969 """
970 PubSub provides publish, subscribe and listen support to Redis channels.
972 After subscribing to one or more channels, the listen() method will block
973 until a message arrives on one of the subscribed channels. That message
974 will be returned and it's safe to start listening again.
975 """
# Message kinds for which handle_message() looks up a registered handler.
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
# Message kinds that make handle_message() drop a pending unsubscription.
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
# Payload attached to the PING issued by check_health().
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
def __init__(
    self,
    connection_pool: ConnectionPool,
    shard_hint: Optional[str] = None,
    ignore_subscribe_messages: bool = False,
    encoder=None,
    push_handler_func: Optional[Callable] = None,
    event_dispatcher: Optional["EventDispatcher"] = None,
):
    """
    Set up an unconnected PubSub bound to ``connection_pool``.

    Args:
        connection_pool: pool the connection is later taken from.
        shard_hint: stored on the instance as-is.
        ignore_subscribe_messages: instance-wide default consulted by
            ``handle_message()`` for subscribe/unsubscribe confirmations.
        encoder: encoder used to normalize channel and pattern names; the
            pool's encoder is used when omitted.
        push_handler_func: optional handler installed on the connection's
            parser by ``connect()``.
        event_dispatcher: dispatcher to use; a new ``EventDispatcher`` is
            created when omitted.
    """
    self._event_dispatcher = (
        EventDispatcher() if event_dispatcher is None else event_dispatcher
    )
    self.connection_pool = connection_pool
    self.shard_hint = shard_hint
    self.ignore_subscribe_messages = ignore_subscribe_messages
    self.connection = None
    self.push_handler_func = push_handler_func
    # The encoding options are needed in order to look up channel and
    # pattern names for callback handlers.
    self.encoder = (
        encoder if encoder is not None else self.connection_pool.get_encoder()
    )
    # Both shapes a health-check reply can take, so that parse_response()
    # can recognise and hide it.
    if not self.encoder.decode_responses:
        self.health_check_response = [
            [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
            self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
        ]
    else:
        self.health_check_response = [
            ["pong", self.HEALTH_CHECK_MESSAGE],
            self.HEALTH_CHECK_MESSAGE,
        ]
    if self.push_handler_func is None:
        _set_info_logger()
    # Subscription registries map a name to its handler (or None); the
    # matching "pending" sets hold names whose unsubscription was requested
    # but not yet confirmed.
    self.channels = {}
    self.pending_unsubscribe_channels = set()
    self.patterns = {}
    self.pending_unsubscribe_patterns = set()
    self.shard_channels = {}
    self.pending_unsubscribe_shard_channels = set()
    self._lock = asyncio.Lock()
async def __aenter__(self):
    """Enter the async context; no connection is opened at this point."""
    return self
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the async context by closing this PubSub via ``aclose()``."""
    await self.aclose()
def __del__(self):
    """
    Detach the reconnect callback from a connection that is still held.

    ``__del__`` also runs for instances whose ``__init__`` raised before
    ``self.connection`` was assigned, so the attribute is looked up
    defensively instead of raising ``AttributeError`` during finalization.
    """
    connection = getattr(self, "connection", None)
    if connection:
        connection.deregister_connect_callback(self.on_connect)
async def aclose(self):
    """
    Release the connection (if any) and forget every subscription.

    Returns immediately when the ``connection`` attribute was never created.
    """
    # A constructor that crashed early (in Redis()) may not have created the
    # ``connection`` attribute yet; then there is nothing to clean up.
    if not hasattr(self, "connection"):
        return
    async with self._lock:
        held = self.connection
        if held:
            # nowait=True avoids awaiting StreamWriter.wait_closed(), which
            # can deadlock when a concurrent reader task (e.g. one running
            # pubsub.run() or get_message(block=True)) still holds the
            # transport. See https://github.com/redis/redis-py/issues/3941
            await held.disconnect(nowait=True)
            held.deregister_connect_callback(self.on_connect)
            await self.connection_pool.release(held)
            self.connection = None
        # Subscription state is cleared whether or not a connection was held.
        self.channels, self.pending_unsubscribe_channels = {}, set()
        self.patterns, self.pending_unsubscribe_patterns = {}, set()
        self.shard_channels, self.pending_unsubscribe_shard_channels = {}, set()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self) -> None:
    """Deprecated spelling of ``aclose()``, kept for backwards compatibility."""
    await self.aclose()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
async def reset(self) -> None:
    """Deprecated spelling of ``aclose()``, kept for backwards compatibility."""
    await self.aclose()
async def _resubscribe(self, subscribed, subscribe_fn) -> None:
    """
    Subscribe again to everything in ``subscribed`` through ``subscribe_fn``.

    Names without a handler are passed positionally, names with a handler as
    keyword arguments; nothing is sent when both groups are empty.
    """
    plain, with_handlers = partition_pubsub_subscriptions_by_handler(
        subscribed, self.encoder
    )
    if not (plain or with_handlers):
        return
    await subscribe_fn(*plain, **with_handlers)
async def _resubscribe_shard_channels(self) -> None:
    """Subscribe again to every tracked shard channel via ``ssubscribe``."""
    tracked = self.shard_channels
    await self._resubscribe(tracked, self.ssubscribe)
async def on_connect(self, connection: Connection):
    """
    Connect callback: restore the subscriptions held before the reconnect.

    Pending unsubscriptions are discarded first, then channels, patterns and
    shard channels are subscribed again, in that order.
    """
    for pending in (
        self.pending_unsubscribe_channels,
        self.pending_unsubscribe_patterns,
        self.pending_unsubscribe_shard_channels,
    ):
        pending.clear()

    if self.channels:
        await self._resubscribe(self.channels, self.subscribe)
    if self.patterns:
        await self._resubscribe(self.patterns, self.psubscribe)
    if self.shard_channels:
        await self._resubscribe_shard_channels()
@property
def subscribed(self):
    """True while at least one channel, pattern or shard channel is tracked."""
    return any((self.channels, self.patterns, self.shard_channels))
async def execute_command(self, *args: EncodableT):
    """
    Send a publish/subscribe command without reading its reply.

    The connection is established first if needed. A health check is
    requested only while nothing is subscribed.
    """
    # NOTE: the response must not be parsed here -- doing so could pull a
    # legitimate message off the stack if the connection is already
    # subscribed to one or more channels
    await self.connect()
    conn = self.connection
    check_health = not self.subscribed
    await self._execute(conn, conn.send_command, *args, check_health=check_health)
async def connect(self):
    """
    Make sure this PubSub holds a connected connection.

    A connection already held is (re)connected. Otherwise one is taken from
    the pool and ``on_connect`` is registered on it. Afterwards the push
    handler, if configured, is installed on the parser and an
    ``AfterPubSubConnectionInstantiationEvent`` is dispatched.
    """
    if self.connection is not None:
        await self.connection.connect()
    else:
        self.connection = await self.connection_pool.get_connection()
        # The callback re-subscribes to whatever we were listening to when
        # the connection was lost.
        self.connection.register_connect_callback(self.on_connect)

    handler = self.push_handler_func
    if handler is not None:
        self.connection._parser.set_pubsub_push_handler(handler)

    event = AfterPubSubConnectionInstantiationEvent(
        self.connection, self.connection_pool, ClientType.ASYNC, self._lock
    )
    self._event_dispatcher.dispatch(event)
async def _reconnect(
    self,
    conn,
    error: Optional[BaseException] = None,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
):
    """
    Retry failure handler: record the failed attempt, then reconnect.

    The retry object has already checked that ``error`` is a supported
    exception, so no such check is repeated here. The duration of the failed
    attempt is recorded only when a command name is known and
    ``failure_count`` does not exceed the configured number of retries.
    """
    within_retry_budget = (
        error
        and failure_count is not None
        and failure_count <= conn.retry.get_retries()
    )
    if within_retry_budget and command_name:
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )
    await conn.disconnect(error=error, failure_count=failure_count)
    await conn.connect()
async def _execute(self, conn, command, *args, **kwargs):
    """
    Run ``command(*args, **kwargs)`` under the connection's retry policy.

    On a retryable failure ``_reconnect`` reconnects manually; if the Redis
    server is down this fails and raises a ConnectionError as desired. After
    a reconnection the ``on_connect`` callback should have been called by
    the connection to resubscribe us to any channels and patterns we were
    previously listening to.

    The operation duration is recorded on success (when a command name is
    known); any ``Exception`` is counted and re-raised.
    """
    command_name = args[0] if args else None

    # Observability: wall-clock start and the number of retries performed.
    started_at = time.monotonic()
    attempts = 0

    def on_failure(error, failure_count):
        nonlocal attempts
        attempts = failure_count
        return self._reconnect(conn, error, failure_count, started_at, command_name)

    def invoke():
        return command(*args, **kwargs)

    try:
        response = await conn.retry.call_with_retry(
            invoke,
            on_failure,
            with_failure_count=True,
        )
        if command_name:
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - started_at,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )
        return response
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=attempts,
            is_internal=False,
        )
        raise
async def parse_response(self, block: bool = True, timeout: float = 0):
    """
    Read the next reply on the pub/sub connection.

    Args:
        block: when True the read has no timeout and ``timeout`` is
            ignored; when False the read gives up after ``timeout``.
            Default: True
        timeout: seconds to wait for a reply when ``block`` is False.
            Default: 0 (return immediately if no data is available)

    Returns:
        The reply read from the connection, or None when the reply is the
        answer to a health check (only filtered while the connection has a
        health check interval).

    Raises:
        RuntimeError: if no connection has been set up yet.

    Most callers should use ``get_message()`` instead, which derives
    ``block`` from its own timeout (``block=True`` only for
    ``timeout=None``).

    Example:
        # Wait without limit (timeout is ignored)
        response = await pubsub.parse_response(block=True, timeout=0.1)

        # Wait at most 0.1 seconds
        response = await pubsub.parse_response(block=False, timeout=0.1)

        # Do not wait at all
        response = await pubsub.parse_response(block=False, timeout=0)
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )

    await self.check_health()

    if not conn.is_connected:
        await conn.connect()

    read_options = {
        "timeout": timeout if not block else None,
        "disconnect_on_error": False,
        "push_request": True,
    }
    response = await self._execute(conn, conn.read_response, **read_options)

    # The health check reply is hidden, as the user might not expect it.
    is_health_reply = (
        conn.health_check_interval and response in self.health_check_response
    )
    return None if is_health_reply else response
async def check_health(self):
    """
    Send a health-check PING when one is due.

    A PING carrying ``HEALTH_CHECK_MESSAGE`` is sent only if the connection
    has a health check interval and the running loop's clock has passed
    ``next_health_check``.

    Raises:
        RuntimeError: if no connection has been set up yet.
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )

    if not conn.health_check_interval:
        return
    due = asyncio.get_running_loop().time() > conn.next_health_check
    if not due:
        return
    await conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
    """
    Return a copy of ``data`` whose keys went through encode-then-decode.

    This makes channel/pattern names either bytes or strings, depending on
    whether responses are automatically decoded, so that incoming messages
    do not have to be coerced one by one.
    """
    encoder = self.encoder
    normalized = {}
    for name, handler in data.items():
        normalized[encoder.decode(encoder.encode(name))] = handler
    return normalized  # type: ignore[return-value]
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
    """
    Subscribe to channel patterns.

    Positional arguments are pattern names without a handler. Keyword
    arguments map a pattern name to a callable, which is invoked for every
    message received on that pattern instead of the message being produced
    by ``listen()``.
    """
    names = list_or_args((args[0],), args[1:]) if args else args
    requested: Dict[ChannelT, PubSubHandler] = dict.fromkeys(names)
    # Mypy bug: https://github.com/python/mypy/issues/10970
    requested.update(kwargs)  # type: ignore[arg-type]
    ret_val = await self.execute_command("PSUBSCRIBE", *requested.keys())
    # The registry is updated only AFTER the command was sent; otherwise a
    # reconnect in between would subscribe to these patterns a second time.
    normalized = self._normalize_keys(requested)
    self.patterns.update(normalized)
    self.pending_unsubscribe_patterns.difference_update(normalized)
    return ret_val
def punsubscribe(self, *args: ChannelT) -> Awaitable:
    """
    Unsubscribe from the given patterns, or from all patterns when none are
    given.

    The affected patterns are only marked as pending here; the awaitable of
    the ``PUNSUBSCRIBE`` command is returned to the caller.
    """
    patterns: Iterable[ChannelT]
    if not args:
        parsed_args = []
        patterns = self.patterns
    else:
        parsed_args = list_or_args((args[0],), args[1:])
        patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
    self.pending_unsubscribe_patterns.update(patterns)
    return self.execute_command("PUNSUBSCRIBE", *parsed_args)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
    """
    Subscribe to channels.

    Positional arguments are channel names without a handler. Keyword
    arguments map a channel name to a callable, which is invoked for every
    message received on that channel instead of the message being produced
    by ``listen()`` or ``get_message()``.
    """
    names = list_or_args((args[0],), args[1:]) if args else ()
    requested = dict.fromkeys(names)
    # Mypy bug: https://github.com/python/mypy/issues/10970
    requested.update(kwargs)  # type: ignore[arg-type]
    ret_val = await self.execute_command("SUBSCRIBE", *requested.keys())
    # The registry is updated only AFTER the command was sent; otherwise a
    # reconnect in between would subscribe to these channels a second time.
    normalized = self._normalize_keys(requested)
    self.channels.update(normalized)
    self.pending_unsubscribe_channels.difference_update(normalized)
    return ret_val
def unsubscribe(self, *args) -> Awaitable:
    """
    Unsubscribe from the given channels, or from all channels when none are
    given.

    The affected channels are only marked as pending here; the awaitable of
    the ``UNSUBSCRIBE`` command is returned to the caller.
    """
    if not args:
        parsed_args = []
        channels = self.channels
    else:
        parsed_args = list_or_args(args[0], args[1:])
        channels = self._normalize_keys(dict.fromkeys(parsed_args))
    self.pending_unsubscribe_channels.update(channels)
    return self.execute_command("UNSUBSCRIBE", *parsed_args)
async def ssubscribe(self, *args, target_node=None, **kwargs):
    """
    Subscribe to shard channels.

    Positional arguments are shard channel names without a handler. Keyword
    arguments map a channel name to a callable, which is invoked for every
    message received on that channel instead of the message being produced
    by ``listen()`` or ``get_sharded_message()``. ``target_node`` is
    accepted but not used by this implementation.
    """
    names = list_or_args(args[0], args[1:]) if args else args
    requested = dict.fromkeys(names)
    requested.update(kwargs)
    ret_val = await self.execute_command("SSUBSCRIBE", *requested.keys())
    # The registry is updated only AFTER the command was sent; otherwise a
    # reconnect in between would subscribe to these channels a second time.
    normalized = self._normalize_keys(requested)
    self.shard_channels.update(normalized)
    self.pending_unsubscribe_shard_channels.difference_update(normalized)
    return ret_val
def sunsubscribe(self, *args, target_node=None) -> Awaitable:
    """
    Unsubscribe from the given shard channels, or from all shard channels
    when none are given.

    The affected channels are only marked as pending here; the awaitable of
    the ``SUNSUBSCRIBE`` command is returned to the caller. ``target_node``
    is accepted but not used by this implementation.
    """
    if not args:
        s_channels = self.shard_channels
    else:
        args = list_or_args(args[0], args[1:])
        s_channels = self._normalize_keys(dict.fromkeys(args))
    self.pending_unsubscribe_shard_channels.update(s_channels)
    return self.execute_command("SUNSUBSCRIBE", *args)
async def listen(self) -> AsyncIterator:
    """
    Yield messages for as long as this client has subscriptions.

    Replies for which ``handle_message()`` returns None are skipped.
    """
    while self.subscribed:
        raw = await self.parse_response(block=True)
        message = await self.handle_message(raw)
        if message is None:
            continue
        yield message
async def get_message(
    self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
    """
    Return the next message if there is one, otherwise None.

    ``timeout`` is the number of seconds (a float) to wait for a message;
    pass None to wait until one arrives.
    """
    wait_forever = timeout is None
    response = await self.parse_response(block=wait_forever, timeout=timeout)
    if not response:
        return None
    return await self.handle_message(response, ignore_subscribe_messages)
def ping(self, message=None) -> Awaitable[bool]:
    """
    Send a PING, optionally carrying ``message``, over the pub/sub connection.

    Returns the awaitable produced by ``execute_command``; the reply itself
    is not read here.
    """
    if message is None:
        return self.execute_command("PING")
    return self.execute_command("PING", message)
async def handle_message(self, response, ignore_subscribe_messages=False):
    """
    Turn a raw pub/sub reply into a message dict and act on it.

    Published messages whose channel or pattern has a registered handler are
    passed to that handler and None is returned. Unsubscribe confirmations
    remove the name from the pending set and from its registry.
    Subscribe/unsubscribe confirmations are suppressed (None) when either
    ``ignore_subscribe_messages`` or the instance-wide flag is set.

    Returns:
        A dict with the keys ``type``, ``pattern``, ``channel`` and ``data``,
        or None.
    """
    if response is None:
        return None
    if isinstance(response, bytes):
        # A bare bytes reply is reshaped into a two-item "pong" message; the
        # literal PONG maps to an empty payload.
        response = [b"pong", b""] if response == b"PONG" else [b"pong", response]

    message_type = str_if_bytes(response[0])
    if message_type == "pmessage":
        pattern, channel, data = response[1], response[2], response[3]
    elif message_type == "pong":
        pattern, channel, data = None, None, response[1]
    else:
        pattern, channel, data = None, response[1], response[2]
    message = {
        "type": message_type,
        "pattern": pattern,
        "channel": channel,
        "data": data,
    }

    # Record received messages for observability.
    if message_type in ("message", "pmessage"):
        await record_pubsub_message(
            direction=PubSubDirection.RECEIVE,
            channel=str_if_bytes(message["channel"]),
        )
    elif message_type == "smessage":
        await record_pubsub_message(
            direction=PubSubDirection.RECEIVE,
            channel=str_if_bytes(message["channel"]),
            sharded=True,
        )

    # An unsubscribe confirmation removes the subscription from memory, but
    # only if the unsubscription is still pending.
    if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
        if message_type == "punsubscribe":
            pending, registry = self.pending_unsubscribe_patterns, self.patterns
        elif message_type == "sunsubscribe":
            pending, registry = (
                self.pending_unsubscribe_shard_channels,
                self.shard_channels,
            )
        else:
            pending, registry = self.pending_unsubscribe_channels, self.channels
        name = response[1]
        if name in pending:
            pending.remove(name)
            registry.pop(name, None)

    if message_type in self.PUBLISH_MESSAGE_TYPES:
        # Invoke the message handler, if there is one.
        if message_type == "pmessage":
            handler = self.patterns.get(message["pattern"], None)
        elif message_type == "smessage":
            handler = self.shard_channels.get(message["channel"], None)
        else:
            handler = self.channels.get(message["channel"], None)
        if handler:
            if inspect.iscoroutinefunction(handler):
                await handler(message)
            else:
                handler(message)
            return None
    elif message_type != "pong":
        # This is a subscribe/unsubscribe message; drop it if unwanted.
        if ignore_subscribe_messages or self.ignore_subscribe_messages:
            return None

    return message
async def run(
    self,
    *,
    exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
    poll_timeout: float = 1.0,
    pubsub=None,
) -> None:
    """Process pub/sub messages using registered callbacks.

    Coroutine counterpart of :py:meth:`redis.PubSub.run_in_thread`. Start it
    as a separate task with ``asyncio.create_task``:

    >>> task = asyncio.create_task(pubsub.run())

    and stop it through asyncio cancellation:

    >>> task.cancel()
    >>> await task

    Every subscribed channel and pattern must have a handler, otherwise
    ``PubSubError`` is raised. Messages are polled from ``pubsub`` when it is
    given, else from this instance. Exceptions other than cancellation are
    passed to ``exception_handler`` (sync or async) when one is supplied and
    re-raised otherwise.
    """
    for channel, handler in self.channels.items():
        if handler is None:
            raise PubSubError(f"Channel: '{channel}' has no handler registered")
    for pattern, handler in self.patterns.items():
        if handler is None:
            raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

    await self.connect()
    source = self if pubsub is None else pubsub
    while True:
        try:
            await source.get_message(
                ignore_subscribe_messages=True, timeout=poll_timeout
            )
        except asyncio.CancelledError:
            raise
        except BaseException as exc:
            if exception_handler is None:
                raise
            outcome = exception_handler(exc, self)
            if inspect.isawaitable(outcome):
                await outcome
        # Give other tasks on the event loop a chance to run in case nothing
        # above had to block for I/O.
        await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Synchronous callable that receives an exception and the PubSub."""

    def __call__(self, e: BaseException, pubsub: PubSub): ...
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Coroutine callable that receives an exception and the PubSub."""

    async def __call__(self, e: BaseException, pubsub: PubSub): ...
# Either flavour of exception handler (sync or async).
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]


# One queued command: its argument tuple plus its options mapping.
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# The ordered list of queued commands.
CommandStackT = List[CommandT]
1595class Pipeline(Redis): # lgtm [py/init-calls-subclass]
1596 """
1597 Pipelines provide a way to transmit multiple commands to the Redis server
1598 in one transmission. This is convenient for batch processing, such as
1599 saving all the values in a list to Redis.
1601 All commands executed within a pipeline (when running in transactional mode,
1602 which is the default behavior) are wrapped with MULTI and EXEC
1603 calls. This guarantees all commands executed in the pipeline will be
1604 executed atomically.
1606 Any command raising an exception does *not* halt the execution of
1607 subsequent commands in the pipeline. Instead, the exception is caught
1608 and its instance is placed into the response list returned by execute().
1609 Code iterating over the response list should be able to deal with an
1610 instance of an exception as a potential value. In general, these will be
1611 ResponseError exceptions, such as those raised when issuing a command
1612 on a key of a different datatype.
1613 """
# NOTE(review): presumably the commands after which no keys remain WATCHed;
# confirm against the code that consumes this set.
UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
def __init__(
    self,
    connection_pool: ConnectionPool,
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
    transaction: bool,
    shard_hint: Optional[str],
):
    """
    Create an empty pipeline that holds no connection yet.

    Args:
        connection_pool: pool connections are taken from and returned to.
        response_callbacks: per-command callbacks applied to replies.
        transaction: stored as ``is_transaction``.
        shard_hint: stored on the instance as-is.
    """
    # Configuration supplied by the caller.
    self.connection_pool = connection_pool
    self.response_callbacks = response_callbacks
    self.is_transaction = transaction
    self.shard_hint = shard_hint
    # Mutable state; reset() returns all of it to these initial values.
    self.connection = None
    self.watching = False
    self.explicit_transaction = False
    self.command_stack: CommandStackT = []
    self.scripts: Set[Script] = set()
async def __aenter__(self: _RedisT) -> _RedisT:
    """Enter the async context; the pipeline itself is the context value."""
    return self
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the async context by resetting the pipeline via ``reset()``."""
    await self.reset()
def __await__(self):
    """Make ``await pipeline`` evaluate to the pipeline itself."""
    resolved = self._async_self()
    return resolved.__await__()
# NOTE(review): presumably the warning text used when a pipeline is collected
# without having been closed; confirm against the base class.
_DEL_MESSAGE = "Unclosed Pipeline client"
def __len__(self):
    """Number of commands currently queued."""
    queued = self.command_stack
    return len(queued)
def __bool__(self):
    """Always True, so that an empty pipeline is not falsy through ``__len__``."""
    return True
async def _async_self(self):
    """Coroutine that resolves to this pipeline; backs ``__await__``."""
    return self
async def reset(self):
    """
    Drop queued commands, clear WATCH state and return the connection.

    If keys are being watched, ``UNWATCH`` is sent first; a connection error
    at that point is answered by disconnecting.
    """
    self.command_stack, self.scripts = [], set()

    # Reset the connection state in case something was being watched.
    if self.watching and self.connection:
        try:
            # Done by hand, because unwatch() and
            # immediate_execute_command() can themselves call reset().
            await self.connection.send_command("UNWATCH")
            await self.connection.read_response()
        except ConnectionError:
            # Disconnecting removes any previous WATCHes as well.
            if self.connection:
                await self.connection.disconnect()

    # Clean up the remaining instance attributes.
    self.watching = self.explicit_transaction = False

    # Nothing is WATCHed any more at this point, so the connection can
    # safely go back to the pool.
    if self.connection:
        await self.connection_pool.release(self.connection)
        self.connection = None
async def aclose(self) -> None:
    """Standard-named cleanup entry point; simply performs ``reset()``."""
    await self.reset()
def multi(self):
    """
    Begin the transactional block of the pipeline, after any WATCH commands
    have been issued. The block is ended with `execute`.

    Raises:
        RedisError: if a transactional block is already open, or if commands
            have already been queued.
    """
    if self.explicit_transaction:
        raise RedisError("Cannot issue nested calls to MULTI")
    if self.command_stack:
        raise RedisError("Commands without an initial WATCH have already been issued")
    self.explicit_transaction = True
def execute_command(
    self, *args, **kwargs
) -> Union["Pipeline", Awaitable["Pipeline"]]:
    """
    Route a command to immediate execution or to the queue.

    Outside an explicit transaction, ``WATCH`` and any command issued while
    watching run immediately; everything else is queued.
    """
    wants_immediate = self.watching or args[0] == "WATCH"
    if wants_immediate and not self.explicit_transaction:
        return self.immediate_execute_command(*args, **kwargs)
    return self.pipeline_execute_command(*args, **kwargs)
async def _disconnect_reset_raise_on_watching(
    self,
    conn: Connection,
    error: Exception,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
) -> None:
    """
    Retry failure handler: disconnect, and give up if keys were watched.

    The retry object has already checked that ``error`` is a supported
    exception, so no such check is repeated here. The duration of the failed
    attempt is recorded while ``failure_count`` does not exceed the
    configured number of retries. After the disconnect, the connection will
    try to reconnect and do a health check as part of the send_command
    logic (at connection level).

    Raises:
        WatchError: if the pipeline was watching keys when the error occurred.
    """
    will_retry = (
        error
        and failure_count is not None
        and failure_count <= conn.retry.get_retries()
    )
    if will_retry:
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )
    await conn.disconnect(error=error, failure_count=failure_count)
    if not self.watching:
        return
    # A WATCH does not survive the death of its connection. Raising
    # WatchError tells the user to retry the whole transaction.
    await self.reset()
    raise WatchError(
        f"A {type(error).__name__} occurred while watching one or more keys"
    )
async def immediate_execute_command(self, *args, **options):
    """
    Execute a command right away instead of queueing it.

    Used for WATCH and for commands issued after it (to read values) before
    MULTI is called. Retries follow the connection's retry policy, except
    that a failure while keys are being watched is not retried.

    The operation duration is recorded on success; any ``Exception`` is
    counted and re-raised.
    """
    command_name = args[0]
    # The first immediate command has to acquire the connection.
    conn = self.connection
    if not conn:
        conn = await self.connection_pool.get_connection()
        self.connection = conn

    # Observability: wall-clock start and the number of retries performed.
    started_at = time.monotonic()
    attempts = 0

    def on_failure(error, failure_count):
        nonlocal attempts
        attempts = failure_count
        return self._disconnect_reset_raise_on_watching(
            conn, error, failure_count, started_at, command_name
        )

    def send_and_parse():
        return self._send_command_parse_response(
            conn, command_name, *args, **options
        )

    try:
        response = await conn.retry.call_with_retry(
            send_and_parse,
            on_failure,
            with_failure_count=True,
        )
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - started_at,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
        return response
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=attempts,
            is_internal=False,
        )
        raise
def pipeline_execute_command(self, *args, **options):
    """
    Queue a command for the next call to execute().

    The pipeline itself is returned, which allows chaining:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

    Calling pipe.execute() later runs every queued command.
    """
    queued = (args, options)
    self.command_stack.append(queued)
    return self
async def _execute_transaction(  # noqa: C901
    self, connection: Connection, commands: CommandStackT, raise_on_error
):
    """
    Send ``commands`` wrapped in MULTI/EXEC and return the processed replies.

    Commands carrying an ``EMPTY_RESPONSE`` option are not sent; their
    placeholder value is inserted into the result instead.

    Raises:
        WatchError: if EXEC returned None.
        ResponseError: if the number of replies does not match the number of
            commands, or (with ``raise_on_error``) for the first failed
            command.
        ExecAbortError: if EXEC was aborted and no earlier error was
            collected; otherwise the first collected error is raised.
    """
    multi_cmd: CommandT = (("MULTI",), {})
    exec_cmd: CommandT = (("EXEC",), {})
    framed = (multi_cmd, *commands, exec_cmd)
    packed = connection.pack_commands(
        args for args, options in framed if EMPTY_RESPONSE not in options
    )
    await connection.send_packed_command(packed)
    errors = []

    # Read the reply to MULTI.
    # NOTE: ResponseErrors have to be collected rather than propagated here,
    # so that every remaining reply is still read from the socket.
    try:
        await self.parse_response(connection, "_")
    except ResponseError as err:
        errors.append((0, err))

    # Read the reply of every queued command.
    for index, (cmd_args, cmd_options) in enumerate(commands):
        if EMPTY_RESPONSE in cmd_options:
            errors.append((index, cmd_options[EMPTY_RESPONSE]))
            continue
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            self.annotate_exception(err, index + 1, cmd_args)
            errors.append((index, err))

    # Read the reply to EXEC.
    try:
        response = await self.parse_response(connection, "_")
    except ExecAbortError as err:
        if errors:
            raise errors[0][1] from err
        raise

    # EXEC clears any watched keys
    self.watching = False

    if response is None:
        raise WatchError("Watched variable changed.") from None

    # Put the collected errors/placeholders at their command positions.
    for position, collected in errors:
        response.insert(position, collected)

    if len(response) != len(commands):
        if self.connection:
            await self.connection.disconnect()
        raise ResponseError(
            "Wrong number of response items from pipeline execution"
        ) from None

    # Raise the first error found in the response, if requested.
    if raise_on_error:
        self.raise_first_error(commands, response)

    # Response callbacks have to be run by hand here.
    data = []
    for reply, (cmd_args, cmd_options) in zip(response, commands):
        if not isinstance(reply, Exception):
            command_name = cmd_args[0]

            # The "keys" entry is only needed for caching; drop it.
            cmd_options.pop("keys", None)

            if command_name in self.response_callbacks:
                reply = self.response_callbacks[command_name](reply, **cmd_options)
                if inspect.isawaitable(reply):
                    reply = await reply
        data.append(reply)
    return data
async def _execute_pipeline(
    self, connection: Connection, commands: CommandStackT, raise_on_error: bool
):
    """
    Send ``commands`` as one packed request and collect one reply each.

    A ResponseError raised while parsing a reply is stored in that
    command's position of the returned list instead of propagating. With
    ``raise_on_error`` set, ``raise_first_error`` is called on the
    collected replies before they are returned.
    """
    # Pack every command into a single request to increase network perf.
    packed = connection.pack_commands([cmd_args for cmd_args, _ in commands])
    await connection.send_packed_command(packed)

    replies = []
    for cmd_args, cmd_options in commands:
        try:
            reply = await self.parse_response(
                connection, cmd_args[0], **cmd_options
            )
        except ResponseError as exc:
            reply = exc
        replies.append(reply)

    if raise_on_error:
        self.raise_first_error(commands, replies)
    return replies
def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
    """
    Raise the first ResponseError found in ``response``, if any.

    Before raising, the error is annotated with its 1-based position and
    the arguments of the matching entry in ``commands``. Returns None
    when no reply is a ResponseError.
    """
    for number, reply in enumerate(response, start=1):
        if not isinstance(reply, ResponseError):
            continue
        self.annotate_exception(reply, number, commands[number - 1][0])
        raise reply
def annotate_exception(
    self, exception: Exception, number: int, command: Iterable[object]
) -> None:
    """
    Rewrite ``exception.args[0]`` to say which pipeline command failed.

    The new message contains ``number``, the space-joined and truncated
    ``command`` and the exception's previous ``args``; any further args
    after the first are kept unchanged.
    """
    rendered = " ".join(safe_str(part) for part in command)
    prefix = f"Command # {number} ({truncate_text(rendered)}) "
    detail = f"of pipeline caused error: {exception.args}"
    exception.args = (prefix + detail, *exception.args[1:])
async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """
    Parse a reply via the parent class and keep ``self.watching`` in sync.

    After the parent call returns, ``watching`` is cleared when
    ``command_name`` is one of ``self.UNWATCH_COMMANDS`` and set when it
    equals "WATCH". The parsed reply is returned unchanged.
    """
    parsed = await super().parse_response(connection, command_name, **options)
    if command_name in self.UNWATCH_COMMANDS:
        self.watching = False
        return parsed
    if command_name == "WATCH":
        self.watching = True
    return parsed
async def load_scripts(self):
    """
    Make sure every script registered on this pipeline exists server-side.

    Issues one "SCRIPT EXISTS" for all script SHAs, then a "SCRIPT LOAD"
    for each script reported missing, storing the reply as that script's
    new ``sha``.
    """
    pending = list(self.scripts)
    # The regular script_* methods can't be used here: they would only be
    # buffered in the pipeline, so the commands are executed immediately.
    run_now = self.immediate_execute_command
    known = await run_now("SCRIPT EXISTS", *[script.sha for script in pending])
    if all(known):
        return
    for script, is_known in zip(pending, known):
        if is_known:
            continue
        script.sha = await run_now("SCRIPT LOAD", script.script)
async def _disconnect_raise_on_watching(
    self,
    conn: Connection,
    error: Exception,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
):
    """
    Close the connection, raise an exception if we were watching.

    The supported exceptions are already checked in the
    retry object so we don't need to do it here.

    After we disconnect the connection, it will try to reconnect and
    do a health check as part of the send_command logic(on connection level).

    An operation-duration metric is recorded first, but only when both
    ``failure_count`` and ``start_time`` were supplied and
    ``failure_count`` does not exceed ``conn.retry.get_retries()``.

    Raises WatchError if ``self.watching`` is set once the connection
    has been disconnected.
    """
    if (
        error
        and failure_count is not None
        # start_time defaults to None; without it no duration can be
        # computed, and subtracting None would raise TypeError before the
        # connection gets disconnected.
        and start_time is not None
        and failure_count <= conn.retry.get_retries()
    ):
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )
    await conn.disconnect(error=error, failure_count=failure_count)
    # if we were watching a variable, the watch is no longer valid
    # since this connection has died. raise a WatchError, which
    # indicates the user should retry this transaction.
    if self.watching:
        raise WatchError(
            f"A {type(error).__name__} occurred while watching one or more keys"
        )
async def execute(self, raise_on_error: bool = True) -> List[Any]:
    """
    Execute all the commands in the current pipeline.

    Returns an empty list right away when nothing is queued and no keys
    are being watched. Otherwise any registered scripts are loaded, the
    queued commands are run through the connection's retry object (as a
    MULTI/EXEC transaction or as a plain pipeline), and the list of
    replies is returned.

    An operation-duration metric is recorded on success and an error
    count on failure; the original exception is re-raised. The pipeline
    is always reset afterwards.
    """
    stack = self.command_stack
    if not (stack or self.watching):
        return []
    if self.scripts:
        await self.load_scripts()

    if self.is_transaction or self.explicit_transaction:
        runner, operation_name = self._execute_transaction, "MULTI"
    else:
        runner, operation_name = self._execute_pipeline, "PIPELINE"

    conn = self.connection
    if not conn:
        conn = await self.connection_pool.get_connection()
        # Store it on self.connection so that reset() hands the
        # connection back to the pool once we are done.
        self.connection = conn
    conn = cast(Connection, conn)

    # Timing reference for observability.
    started_at = time.monotonic()
    # Number of retry attempts actually made, for error reporting.
    attempts_made = 0

    def on_failure(error, failure_count):
        nonlocal attempts_made
        attempts_made = failure_count
        return self._disconnect_raise_on_watching(
            conn, error, failure_count, started_at, operation_name
        )

    try:
        replies = await conn.retry.call_with_retry(
            lambda: runner(conn, stack, raise_on_error),
            on_failure,
            with_failure_count=True,
        )

        await record_operation_duration(
            command_name=operation_name,
            duration_seconds=time.monotonic() - started_at,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
        return replies
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=attempts_made,
            is_internal=False,
        )
        raise
    finally:
        await self.reset()
async def discard(self):
    """Send DISCARD, flushing all previously queued commands.

    See: https://redis.io/commands/DISCARD

    The command's reply is not returned.
    """
    command_name = "DISCARD"
    await self.execute_command(command_name)
    return None
async def watch(self, *names: KeyT):
    """
    Watch the values at keys ``names``.

    Returns the result of the WATCH command. Raises RedisError when
    ``self.explicit_transaction`` is set, i.e. after a MULTI.
    """
    if not self.explicit_transaction:
        return await self.execute_command("WATCH", *names)
    raise RedisError("Cannot issue a WATCH after a MULTI")
async def unwatch(self):
    """
    Unwatch all previously specified keys.

    UNWATCH is only sent when ``self.watching`` is truthy. Returns the
    command's reply when that reply is truthy, and True in every other
    case (nothing was being watched, or the reply was falsy).
    """
    if not self.watching:
        return True
    reply = await self.execute_command("UNWATCH")
    # A falsy reply is still reported as True, as it always has been.
    return reply if reply else True