Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 21%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import copy
3import inspect
4import re
5import time
6import warnings
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 AsyncIterator,
11 Awaitable,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Literal,
17 Mapping,
18 MutableMapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27 cast,
28)
30from redis._parsers.helpers import (
31 _RedisCallbacks,
32 _RedisCallbacksRESP2,
33 _RedisCallbacksRESP3,
34 bool_ok,
35)
36from redis.asyncio.connection import (
37 Connection,
38 ConnectionPool,
39 SSLConnection,
40 UnixDomainSocketConnection,
41)
42from redis.asyncio.lock import Lock
43from redis.asyncio.observability.recorder import (
44 record_error_count,
45 record_operation_duration,
46 record_pubsub_message,
47)
48from redis.asyncio.retry import Retry
49from redis.backoff import ExponentialWithJitterBackoff
50from redis.client import (
51 EMPTY_RESPONSE,
52 NEVER_DECODE,
53 AbstractRedis,
54 CaseInsensitiveDict,
55)
56from redis.commands import (
57 AsyncCoreCommands,
58 AsyncRedisModuleCommands,
59 AsyncSentinelCommands,
60 list_or_args,
61)
62from redis.credentials import CredentialProvider
63from redis.driver_info import DriverInfo, resolve_driver_info
64from redis.event import (
65 AfterPooledConnectionsInstantiationEvent,
66 AfterPubSubConnectionInstantiationEvent,
67 AfterSingleConnectionInstantiationEvent,
68 ClientType,
69 EventDispatcher,
70)
71from redis.exceptions import (
72 ConnectionError,
73 ExecAbortError,
74 PubSubError,
75 RedisError,
76 ResponseError,
77 WatchError,
78)
79from redis.observability.attributes import PubSubDirection
80from redis.typing import ChannelT, EncodableT, KeyT
81from redis.utils import (
82 SSL_AVAILABLE,
83 _set_info_logger,
84 deprecated_args,
85 deprecated_function,
86 safe_str,
87 str_if_bytes,
88 truncate_text,
89)
# ssl types are only needed for static analysis; at runtime (or when ssl is
# unavailable) they are replaced by None placeholders so that the annotations
# in Redis.__init__ below still evaluate without importing ssl.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Signature of an async pub/sub message handler: takes the message dict,
# returns an awaitable.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Bound TypeVar so methods like initialize() return the concrete subclass type.
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])

if TYPE_CHECKING:
    # Imported only for annotations to avoid import cycles at runtime.
    from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications
    from redis.commands.core import Script
class ResponseCallbackProtocol(Protocol):
    """Structural type for a synchronous response-post-processing callback."""

    def __call__(self, response: Any, **kwargs): ...
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type for an asynchronous response-post-processing callback."""

    async def __call__(self, response: Any, **kwargs): ...
116ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Type discrimination marker for @overload self-type pattern
    _is_async_client: Literal[True] = True

    # Maps command name (str or bytes) to the callback that post-processes
    # its raw reply; populated in __init__ from the RESP2/RESP3 callback sets.
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]
139 @classmethod
140 def from_url(
141 cls: Type["Redis"],
142 url: str,
143 single_connection_client: bool = False,
144 auto_close_connection_pool: Optional[bool] = None,
145 **kwargs,
146 ) -> "Redis":
147 """
148 Return a Redis client object configured from the given URL
150 For example::
152 redis://[[username]:[password]]@localhost:6379/0
153 rediss://[[username]:[password]]@localhost:6379/0
154 unix://[username@]/path/to/socket.sock?db=0[&password=password]
156 Three URL schemes are supported:
158 - `redis://` creates a TCP socket connection. See more at:
159 <https://www.iana.org/assignments/uri-schemes/prov/redis>
160 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
161 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
162 - ``unix://``: creates a Unix Domain Socket connection.
164 The username, password, hostname, path and all querystring values
165 are passed through urllib.parse.unquote in order to replace any
166 percent-encoded values with their corresponding characters.
168 There are several ways to specify a database number. The first value
169 found will be used:
171 1. A ``db`` querystring option, e.g. redis://localhost?db=0
173 2. If using the redis:// or rediss:// schemes, the path argument
174 of the url, e.g. redis://localhost/0
176 3. A ``db`` keyword argument to this function.
178 If none of these options are specified, the default db=0 is used.
180 All querystring options are cast to their appropriate Python types.
181 Boolean arguments can be specified with string values "True"/"False"
182 or "Yes"/"No". Values that cannot be properly cast cause a
183 ``ValueError`` to be raised. Once parsed, the querystring arguments
184 and keyword arguments are passed to the ``ConnectionPool``'s
185 class initializer. In the case of conflicting arguments, querystring
186 arguments always win.
188 """
189 connection_pool = ConnectionPool.from_url(url, **kwargs)
190 client = cls(
191 connection_pool=connection_pool,
192 single_connection_client=single_connection_client,
193 )
194 if auto_close_connection_pool is not None:
195 warnings.warn(
196 DeprecationWarning(
197 '"auto_close_connection_pool" is deprecated '
198 "since version 5.0.1. "
199 "Please create a ConnectionPool explicitly and "
200 "provide to the Redis() constructor instead."
201 )
202 )
203 else:
204 auto_close_connection_pool = True
205 client.auto_close_connection_pool = auto_close_connection_pool
206 return client
208 @classmethod
209 def from_pool(
210 cls: Type["Redis"],
211 connection_pool: ConnectionPool,
212 ) -> "Redis":
213 """
214 Return a Redis client from the given connection pool.
215 The Redis client will take ownership of the connection pool and
216 close it when the Redis client is closed.
217 """
218 client = cls(
219 connection_pool=connection_pool,
220 )
221 client.auto_close_connection_pool = True
222 return client
    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        # NOTE: mutable default instance is shared across calls, but it is
        # deep-copied before being handed to the pool (see below), so each
        # client gets its own Retry state.
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional["DriverInfo"] = None,
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object(in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used.
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []

            # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
            computed_driver_info = resolve_driver_info(
                driver_info, lib_name, lib_version
            )

            # Connection arguments common to both TCP and Unix-socket paths.
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # Deep-copy so the (shared) default Retry is never mutated.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "driver_info": computed_driver_info,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    # SSL is only meaningful for TCP connections.
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                            "ssl_password": ssl_password,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        # Populated lazily in initialize() when single_connection_client=True.
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # The pool's kwargs are authoritative for the protocol version
        # (it may have been set via URL querystring as a string).
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()
428 def __repr__(self):
429 return (
430 f"<{self.__class__.__module__}.{self.__class__.__name__}"
431 f"({self.connection_pool!r})>"
432 )
    def __await__(self):
        # Allow `await Redis(...)` as a shorthand for awaiting initialize().
        return self.initialize().__await__()
    async def initialize(self: _RedisT) -> _RedisT:
        """
        Prepare the client for use and return it.

        In single-connection mode, lazily acquire the dedicated connection
        under the single-connection lock (so concurrent initializers cannot
        race) and dispatch the instantiation event exactly once.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    # Only dispatched on first acquisition of the connection.
                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self
    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback for ``command``."""
        self.response_callbacks[command] = callback
    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()
    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        # Returned mapping is the pool's live dict; mutations affect new
        # connections (see set_retry below).
        return self.connection_pool.connection_kwargs
462 def get_retry(self) -> Optional[Retry]:
463 return self.get_connection_kwargs().get("retry")
465 def set_retry(self, retry: Retry) -> None:
466 self.get_connection_kwargs().update({"retry": retry})
467 self.connection_pool.set_retry(retry)
    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

            from redis import Redis
            from foomodule import F
            r = Redis()
            r.load_external_module("foo", F)
            r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)
    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )
    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        Retries forever on WatchError (a watched key changed before EXEC),
        sleeping ``watch_delay`` seconds between attempts when given.
        Returns ``func``'s result if ``value_from_callable`` is True,
        otherwise the pipeline's execute() result.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    # `func` may be sync or async; await its result if needed.
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    # Watched key changed under us: optionally back off, retry.
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue
    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.
        """
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )
    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.

        Keyword arguments are forwarded to the PubSub constructor.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )
    def keyspace_notifications(
        self,
        key_prefix: Union[str, bytes, None] = None,
        ignore_subscribe_messages: bool = True,
    ) -> "AsyncKeyspaceNotifications":
        """
        Return an :class:`~redis.asyncio.keyspace_notifications.AsyncKeyspaceNotifications`
        object for subscribing to keyspace and keyevent notifications.

        Note: Keyspace notifications must be enabled on the Redis server via
        the ``notify-keyspace-events`` configuration option.

        Args:
            key_prefix: Optional prefix to filter and strip from keys in
                        notifications.
            ignore_subscribe_messages: If True, subscribe/unsubscribe
                                       confirmations are not returned by
                                       get_message/listen.
        """
        # Imported locally to avoid a circular import at module load time.
        from redis.asyncio.keyspace_notifications import AsyncKeyspaceNotifications

        return AsyncKeyspaceNotifications(
            self,
            key_prefix=key_prefix,
            ignore_subscribe_messages=ignore_subscribe_messages,
        )
    def monitor(self) -> "Monitor":
        """Return a Monitor object for inspecting commands sent to the server."""
        return Monitor(self.connection_pool)
656 def client(self) -> "Redis":
657 return self.__class__(
658 connection_pool=self.connection_pool, single_connection_client=True
659 )
    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise
676 async def _increment_usage(self) -> int:
677 """
678 Helper coroutine to increment the usage counter while holding the lock.
679 Returns the new value of the usage counter.
680 """
681 async with self._usage_lock:
682 self._usage_counter += 1
683 return self._usage_counter
685 async def _decrement_usage(self) -> int:
686 """
687 Helper coroutine to decrement the usage counter while holding the lock.
688 Returns the new value of the usage counter.
689 """
690 async with self._usage_lock:
691 self._usage_counter -= 1
692 return self._usage_counter
    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        # shield() keeps the counter consistent (and the pool properly closed)
        # even if the exiting task is cancelled mid-cleanup.
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())
    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Warn (and notify the loop's exception handler) about a client that
        # still holds a live connection at garbage collection.
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop; nothing to report to.
                pass
            self.connection._close()
    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            # Release the dedicated single connection back to the pool first.
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()
    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)
749 async def _send_command_parse_response(self, conn, command_name, *args, **options):
750 """
751 Send a command and parse the response
752 """
753 await conn.send_command(*args)
754 return await self.parse_response(conn, command_name, **options)
    async def _close_connection(
        self,
        conn: Connection,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        # Record the failed attempt's duration only while retries remain;
        # the final failure is recorded by the caller's error path.
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
                error=error,
                retry_attempts=failure_count,
            )

        await conn.disconnect(error=error, failure_count=failure_count)
    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        # Use the dedicated connection if in single-connection mode,
        # otherwise borrow one from the pool for this command.
        conn = self.connection or await pool.get_connection()

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            # Invoked by the retry machinery on each failure: remember the
            # attempt count and close the connection before the next try.
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._close_connection(
                conn, error, failure_count, start_time, command_name
            )

        if self.single_connection_client:
            # Serialize use of the shared single connection.
            await self._single_conn_lock.acquire()
        try:
            result = await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )
            return result
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            # Only release if the connection was borrowed from the pool.
            if not self.connection:
                await pool.release(conn)
    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                # Caller requested the raw (undecoded) reply.
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Some commands declare a substitute value for an error reply.
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            # Callbacks may be sync or async (see ResponseCallbackT).
            return await retval if inspect.isawaitable(retval) else retval
        return response
875StrictRedis = Redis
class MonitorCommandInfo(TypedDict):
    """One parsed entry from the MONITOR stream (see Monitor.next_command)."""

    time: float  # server timestamp of the command
    db: int  # database index the command ran against
    client_address: str  # peer address, or "lua"/"unix" for those sources
    client_port: str  # peer port; empty for lua clients
    client_type: str  # "tcp", "unix" or "lua"
    command: str  # the command text with quoting removed
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.
    """

    # MONITOR lines look like: '<time> [<db> <client-info>] "CMD" "arg" ...'
    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    # Extract each double-quoted piece; (?<!\\) skips escaped quotes.
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        # Acquired lazily on first use; dropped in __aexit__.
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Acquire a dedicated connection from the pool (idempotent)."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def __aenter__(self):
        await self.connect()
        await self.connection.send_command("MONITOR")
        # check that monitor returns 'OK', but don't return it to user
        response = await self.connection.read_response()
        if not bool_ok(response):
            raise RedisError(f"MONITOR failed: {response}")
        return self

    async def __aexit__(self, *args):
        # The connection is in MONITOR mode and cannot be reused for normal
        # commands, so disconnect it before returning it to the pool.
        await self.connection.disconnect()
        await self.connection_pool.release(self.connection)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            # force=True: decode even if decode_responses is off.
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.
    """

    # Message types that carry actual published payloads (regular, pattern
    # and sharded publishes, respectively).
    PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
    # Confirmation types for unsubscribes; used to prune local state.
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
    # Payload of PING-based health checks, so replies can be recognized and
    # filtered out before reaching the user.
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"
    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        """Initialize pub/sub state; no connection is created yet.

        Args:
            connection_pool: Pool the dedicated pub/sub connection is drawn from.
            shard_hint: Opaque hint forwarded when acquiring connections.
            ignore_subscribe_messages: Default filter for (un)subscribe
                confirmation messages in handle_message().
            encoder: Encoder used to normalize channel/pattern names; taken
                from the pool when not supplied.
            push_handler_func: Optional raw push-message handler installed on
                the connection parser.
            event_dispatcher: Dispatcher for connection lifecycle events; a
                fresh one is created when not supplied.
        """
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        # The dedicated connection; acquired lazily by connect().
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to lookup channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        # Pre-compute both shapes a health-check PING reply may take so
        # parse_response() can recognize and drop them cheaply.
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        # name -> handler maps (handler may be None) plus the sets of names
        # whose unsubscribe confirmations are still in flight.
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self.shard_channels = {}
        self.pending_unsubscribe_shard_channels = set()
        self._lock = asyncio.Lock()
    async def __aenter__(self):
        # No eager connection: it is created lazily by subscribe()/connect().
        return self
    async def __aexit__(self, exc_type, exc_value, traceback):
        # Release the connection and clear subscription state on exit.
        await self.aclose()
1022 def __del__(self):
1023 if self.connection:
1024 self.connection.deregister_connect_callback(self.on_connect)
    async def aclose(self):
        """Disconnect, release the connection and clear all subscription state."""
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            # Drop all local subscription bookkeeping; the server-side
            # subscriptions died with the connection.
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()
            self.shard_channels = {}
            self.pending_unsubscribe_shard_channels = set()
    # Deprecated alias kept so pre-5.0.1 code keeps working.
    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()
    # Deprecated alias kept so pre-5.0.1 code keeps working.
    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()
1055 async def on_connect(self, connection: Connection):
1056 """Re-subscribe to any channels and patterns previously subscribed to"""
1057 # NOTE: for python3, we can't pass bytestrings as keyword arguments
1058 # so we need to decode channel/pattern names back to unicode strings
1059 # before passing them to [p]subscribe.
1060 #
1061 # However, channels subscribed without a callback (positional args) may
1062 # have binary names that are not valid in the current encoding (e.g.
1063 # arbitrary bytes that are not valid UTF-8). These channels are stored
1064 # with a ``None`` handler. We re-subscribe them as positional args so
1065 # that no decoding is required.
1066 self.pending_unsubscribe_channels.clear()
1067 self.pending_unsubscribe_patterns.clear()
1068 self.pending_unsubscribe_shard_channels.clear()
1069 if self.channels:
1070 channels_with_handlers = {}
1071 channels_without_handlers = []
1072 for k, v in self.channels.items():
1073 if v is not None:
1074 channels_with_handlers[self.encoder.decode(k, force=True)] = v
1075 else:
1076 channels_without_handlers.append(k)
1077 if channels_with_handlers or channels_without_handlers:
1078 await self.subscribe(
1079 *channels_without_handlers, **channels_with_handlers
1080 )
1081 if self.patterns:
1082 patterns_with_handlers = {}
1083 patterns_without_handlers = []
1084 for k, v in self.patterns.items():
1085 if v is not None:
1086 patterns_with_handlers[self.encoder.decode(k, force=True)] = v
1087 else:
1088 patterns_without_handlers.append(k)
1089 if patterns_with_handlers or patterns_without_handlers:
1090 await self.psubscribe(
1091 *patterns_without_handlers, **patterns_with_handlers
1092 )
1093 if self.shard_channels:
1094 shard_with_handlers = {}
1095 shard_without_handlers = []
1096 for k, v in self.shard_channels.items():
1097 if v is not None:
1098 shard_with_handlers[self.encoder.decode(k, force=True)] = v
1099 else:
1100 shard_without_handlers.append(k)
1101 if shard_with_handlers or shard_without_handlers:
1102 await self.ssubscribe(*shard_without_handlers, **shard_with_handlers)
    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        # Includes sharded channels; an empty PubSub is "not subscribed".
        return bool(self.channels or self.patterns or self.shard_channels)
1109 async def execute_command(self, *args: EncodableT):
1110 """Execute a publish/subscribe command"""
1112 # NOTE: don't parse the response in this function -- it could pull a
1113 # legitimate message off the stack if the connection is already
1114 # subscribed to one or more channels
1116 await self.connect()
1117 connection = self.connection
1118 kwargs = {"check_health": not self.subscribed}
1119 await self._execute(connection, connection.send_command, *args, **kwargs)
    async def connect(self):
        """
        Ensure that the PubSub is connected
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            # Re-establish the existing connection if it dropped.
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        # Notify listeners that a pub/sub connection is (re)usable, sharing
        # our lock so they can coordinate with aclose().
        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )
    async def _reconnect(
        self,
        conn,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        In this error handler we are trying to reconnect to the server.
        """
        # Record the failed attempt's duration only while retries remain
        # (failure_count <= configured retries).
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            if command_name:
                # NOTE(review): start_time is assumed to accompany
                # command_name -- every caller in this file passes both
                # together; otherwise the subtraction below would fail.
                await record_operation_duration(
                    command_name=command_name,
                    duration_seconds=time.monotonic() - start_time,
                    server_address=getattr(conn, "host", None),
                    server_port=getattr(conn, "port", None),
                    db_namespace=str(conn.db),
                    error=error,
                    retry_attempts=failure_count,
                )
        await conn.disconnect(error=error, failure_count=failure_count)
        await conn.connect()
1173 async def _execute(self, conn, command, *args, **kwargs):
1174 """
1175 Connect manually upon disconnection. If the Redis server is down,
1176 this will fail and raise a ConnectionError as desired.
1177 After reconnection, the ``on_connect`` callback should have been
1178 called by the # connection to resubscribe us to any channels and
1179 patterns we were previously listening to
1180 """
1181 if not len(args) == 0:
1182 command_name = args[0]
1183 else:
1184 command_name = None
1186 # Start timing for observability
1187 start_time = time.monotonic()
1188 # Track actual retry attempts for error reporting
1189 actual_retry_attempts = 0
1191 def failure_callback(error, failure_count):
1192 nonlocal actual_retry_attempts
1193 actual_retry_attempts = failure_count
1194 return self._reconnect(conn, error, failure_count, start_time, command_name)
1196 try:
1197 response = await conn.retry.call_with_retry(
1198 lambda: command(*args, **kwargs),
1199 failure_callback,
1200 with_failure_count=True,
1201 )
1203 if command_name:
1204 await record_operation_duration(
1205 command_name=command_name,
1206 duration_seconds=time.monotonic() - start_time,
1207 server_address=getattr(conn, "host", None),
1208 server_port=getattr(conn, "port", None),
1209 db_namespace=str(conn.db),
1210 )
1212 return response
1213 except Exception as e:
1214 await record_error_count(
1215 server_address=getattr(conn, "host", None),
1216 server_port=getattr(conn, "port", None),
1217 network_peer_address=getattr(conn, "host", None),
1218 network_peer_port=getattr(conn, "port", None),
1219 error_type=e,
1220 retry_attempts=actual_retry_attempts,
1221 is_internal=False,
1222 )
1223 raise
    async def parse_response(self, block: bool = True, timeout: float = 0):
        """
        Parse the response from a publish/subscribe command.

        Args:
            block: If True, block indefinitely until a message is available.
                If False, return immediately if no message is available.
                Default: True
            timeout: The timeout in seconds for reading a response when block=False.
                This parameter is ignored when block=True.
                Default: 0 (return immediately if no data available)

        Returns:
            The parsed response from the server, or None if no message is available
            within the timeout period (when block=False).

        Raises:
            RuntimeError: if no pub/sub connection has been established yet.

        Important:
            The block and timeout parameters work together:
            - When block=True: timeout is IGNORED, method blocks indefinitely
            - When block=False: timeout is USED, method returns after timeout expires

            Typically, you should use get_message(timeout=X) instead of calling
            parse_response() directly. The get_message() method automatically sets
            block=False when a timeout is provided, and block=True when timeout=None.

        Example:
            # Block indefinitely (timeout is ignored)
            response = await pubsub.parse_response(block=True, timeout=0.1)

            # Non-blocking with 0.1 second timeout
            response = await pubsub.parse_response(block=False, timeout=0.1)

            # Non-blocking, return immediately
            response = await pubsub.parse_response(block=False, timeout=0)

            # Recommended: use get_message() instead
            msg = await pubsub.get_message(timeout=0.1)  # automatically sets block=False
            msg = await pubsub.get_message(timeout=None)  # automatically sets block=True
        """
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        # May emit a health-check PING; its reply is filtered out below.
        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        # block=True is signalled to read_response as "no timeout".
        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as user might not expect it
            return None
        return response
1290 async def check_health(self):
1291 conn = self.connection
1292 if conn is None:
1293 raise RuntimeError(
1294 "pubsub connection not set: "
1295 "did you forget to call subscribe() or psubscribe()?"
1296 )
1298 if (
1299 conn.health_check_interval
1300 and asyncio.get_running_loop().time() > conn.next_health_check
1301 ):
1302 await conn.send_command(
1303 "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
1304 )
1306 def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
1307 """
1308 normalize channel/pattern names to be either bytes or strings
1309 based on whether responses are automatically decoded. this saves us
1310 from coercing the value for each message coming in.
1311 """
1312 encode = self.encoder.encode
1313 decode = self.encoder.decode
1314 return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501
    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        # Positional patterns get a None handler; keyword patterns keep theirs.
        parsed_args = list_or_args((args[0],), args[1:]) if args else args
        new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_patterns.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
        # update the patterns dict AFTER we send the command. we don't want to
        # subscribe twice to these patterns, once for the command and again
        # for the reconnection.
        new_patterns = self._normalize_keys(new_patterns)
        self.patterns.update(new_patterns)
        # A re-subscribe cancels any pending unsubscribe for the same names.
        self.pending_unsubscribe_patterns.difference_update(new_patterns)
        return ret_val
    def punsubscribe(self, *args: ChannelT) -> Awaitable:
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        patterns: Iterable[ChannelT]
        if args:
            parsed_args = list_or_args((args[0],), args[1:])
            patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
        else:
            # Empty PUNSUBSCRIBE unsubscribes everything server-side.
            parsed_args = []
            patterns = self.patterns
        # Marked pending until the server confirms with a 'punsubscribe'
        # message (see handle_message), which then prunes self.patterns.
        self.pending_unsubscribe_patterns.update(patterns)
        return self.execute_command("PUNSUBSCRIBE", *parsed_args)
    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        # Positional channels get a None handler; keyword channels keep theirs.
        parsed_args = list_or_args((args[0],), args[1:]) if args else ()
        new_channels = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_channels.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
        # update the channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_channels = self._normalize_keys(new_channels)
        self.channels.update(new_channels)
        # A re-subscribe cancels any pending unsubscribe for the same names.
        self.pending_unsubscribe_channels.difference_update(new_channels)
        return ret_val
    def unsubscribe(self, *args) -> Awaitable:
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            # Empty UNSUBSCRIBE unsubscribes everything server-side.
            parsed_args = []
            channels = self.channels
        # Marked pending until the server confirms with an 'unsubscribe'
        # message (see handle_message), which then prunes self.channels.
        self.pending_unsubscribe_channels.update(channels)
        return self.execute_command("UNSUBSCRIBE", *parsed_args)
    async def ssubscribe(self, *args, target_node=None, **kwargs):
        """
        Subscribes the client to the specified shard channels.
        Channels supplied as keyword arguments expect a channel name as the key
        and a callable as the value. A channel's callable will be invoked automatically
        when a message is received on that channel rather than producing a message via
        ``listen()`` or ``get_sharded_message()``.
        """
        # ``target_node`` is accepted for signature parity (presumably with the
        # cluster API) and is unused here.
        if args:
            args = list_or_args(args[0], args[1:])
        new_s_channels = dict.fromkeys(args)
        new_s_channels.update(kwargs)
        ret_val = await self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
        # update the s_channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_s_channels = self._normalize_keys(new_s_channels)
        self.shard_channels.update(new_s_channels)
        # A re-subscribe cancels any pending unsubscribe for the same names.
        self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
        return ret_val
    def sunsubscribe(self, *args, target_node=None) -> Awaitable:
        """
        Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
        all shard_channels
        """
        # ``target_node`` is accepted for signature parity (presumably with the
        # cluster API) and is unused here.
        if args:
            args = list_or_args(args[0], args[1:])
            s_channels = self._normalize_keys(dict.fromkeys(args))
        else:
            s_channels = self.shard_channels
        # Marked pending until the server confirms with a 'sunsubscribe'
        # message (see handle_message), which then prunes self.shard_channels.
        self.pending_unsubscribe_shard_channels.update(s_channels)
        return self.execute_command("SUNSUBSCRIBE", *args)
1421 async def listen(self) -> AsyncIterator:
1422 """Listen for messages on channels this client has been subscribed to"""
1423 while self.subscribed:
1424 response = await self.handle_message(await self.parse_response(block=True))
1425 if response is not None:
1426 yield response
1428 async def get_message(
1429 self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
1430 ):
1431 """
1432 Get the next message if one is available, otherwise None.
1434 If timeout is specified, the system will wait for `timeout` seconds
1435 before returning. Timeout should be specified as a floating point
1436 number or None to wait indefinitely.
1437 """
1438 response = await self.parse_response(block=(timeout is None), timeout=timeout)
1439 if response:
1440 return await self.handle_message(response, ignore_subscribe_messages)
1441 return None
1443 def ping(self, message=None) -> Awaitable[bool]:
1444 """
1445 Ping the Redis server to test connectivity.
1447 Sends a PING command to the Redis server and returns True if the server
1448 responds with "PONG".
1449 """
1450 args = ["PING", message] if message is not None else ["PING"]
1451 return self.execute_command(*args)
    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Returns the parsed message dict, or None when the message was consumed
        by a handler, filtered out, or ``response`` was None.
        """
        if response is None:
            return None
        if isinstance(response, bytes):
            # A bare bytes reply is a PING response; normalize it into the
            # [type, data] list shape used below.
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        # Report received publishes to the observability recorder.
        if message_type in ["message", "pmessage"]:
            channel = str_if_bytes(message["channel"])
            await record_pubsub_message(
                direction=PubSubDirection.RECEIVE,
                channel=channel,
            )
        elif message_type == "smessage":
            channel = str_if_bytes(message["channel"])
            await record_pubsub_message(
                direction=PubSubDirection.RECEIVE,
                channel=channel,
                sharded=True,
            )

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            elif message_type == "sunsubscribe":
                s_channel = response[1]
                if s_channel in self.pending_unsubscribe_shard_channels:
                    self.pending_unsubscribe_shard_channels.remove(s_channel)
                    self.shard_channels.pop(s_channel, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            elif message_type == "smessage":
                handler = self.shard_channels.get(message["channel"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                # Consumed by the handler; nothing for the caller.
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message
    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
        pubsub=None,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        # Every subscription must have a handler: run() delivers messages
        # exclusively through callbacks.
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                # get_message() dispatches to the handlers; the returned
                # value is intentionally discarded.
                if pubsub is None:
                    await self.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
                else:
                    await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
            except asyncio.CancelledError:
                # Propagate task cancellation untouched.
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Structural type of a synchronous exception handler for PubSub.run()."""

    def __call__(self, e: BaseException, pubsub: PubSub): ...
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Structural type of an async exception handler for PubSub.run()."""

    async def __call__(self, e: BaseException, pubsub: PubSub): ...
# Either flavor of exception handler accepted by PubSub.run().
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]

# A staged pipeline command: (positional command args, execution options).
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# The pipeline's buffer of staged commands.
CommandStackT = List[CommandT]
class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline(when running in transactional mode,
    which is the default behavior) are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    # Commands that implicitly clear the server-side WATCH state.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    def __init__(
        self,
        connection_pool: ConnectionPool,
        response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
        transaction: bool,
        shard_hint: Optional[str],
    ):
        """Initialize an empty pipeline.

        Note: deliberately does not call ``Redis.__init__`` (see the lgtm
        annotation on the class); only the attributes below are set up.

        Args:
            connection_pool: Pool a connection is drawn from on first use.
            response_callbacks: Per-command response post-processors.
            transaction: When True, wrap queued commands in MULTI/EXEC.
            shard_hint: Opaque hint forwarded to the connection pool.
        """
        self.connection_pool = connection_pool
        # Acquired lazily by the first immediate command / execute().
        self.connection = None
        self.response_callbacks = response_callbacks
        self.is_transaction = transaction
        self.shard_hint = shard_hint
        # True while a WATCH is active on the held connection.
        self.watching = False
        self.command_stack: CommandStackT = []
        self.scripts: Set[Script] = set()
        # Set by multi() to force buffering even while watching.
        self.explicit_transaction = False
    async def __aenter__(self: _RedisT) -> _RedisT:
        # Entering the context manager performs no I/O.
        return self
    async def __aexit__(self, exc_type, exc_value, traceback):
        # Always reset on exit so WATCHes and the pooled connection are released.
        await self.reset()
    def __await__(self):
        # Allow ``await pipeline`` to resolve to the pipeline itself.
        return self._async_self().__await__()

    # Warning text used when a pipeline is garbage-collected unclosed.
    _DEL_MESSAGE = "Unclosed Pipeline client"
    def __len__(self):
        # Number of commands currently staged on the pipeline.
        return len(self.command_stack)
    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        # Without this, an empty pipeline would be falsy via __len__.
        return True
    async def _async_self(self):
        # Coroutine counterpart used by __await__.
        return self
    async def reset(self):
        """Discard staged commands and return the pipeline to a clean state.

        Sends UNWATCH when a WATCH is active and releases the held
        connection back to the pool.
        """
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None
    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()
    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            # WATCH-phase commands run immediately and never reach the stack,
            # so a non-empty stack means buffering started without a WATCH.
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True
1709 def execute_command(
1710 self, *args, **kwargs
1711 ) -> Union["Pipeline", Awaitable["Pipeline"]]:
1712 if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
1713 return self.immediate_execute_command(*args, **kwargs)
1714 return self.pipeline_execute_command(*args, **kwargs)
    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ) -> None:
        """
        Close the connection reset watching state and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        # Record the failed attempt's duration only while retries remain.
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            # NOTE(review): unlike PubSub._reconnect, this recording is not
            # guarded by a command_name check; with the default None values
            # ``time.monotonic() - start_time`` would fail. Callers in this
            # file always pass both -- confirm before relying on defaults.
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
                error=error,
                retry_attempts=failure_count,
            )
        await conn.disconnect(error=error, failure_count=failure_count)
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )
    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors for retry if we're already WATCHing a variable.
        Used when issuing WATCH or subsequent commands retrieving their values
        but before MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            # Disconnects, records the attempt, and raises WatchError if a
            # WATCH was active (invalidated by the dead connection).
            return self._disconnect_reset_raise_on_watching(
                conn, error, failure_count, start_time, command_name
            )

        try:
            response = await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )

            return response
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self
    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        """Send the staged commands wrapped in MULTI/EXEC and collect results.

        Every queued reply is read even when individual commands error, so
        the socket stays in sync. Raises WatchError when EXEC returns nil
        (a watched key changed); re-raises the first queuing error when the
        transaction aborts.
        """
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        # Commands flagged EMPTY_RESPONSE are not sent at all; their stored
        # error is spliced into the result list later.
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove keys entry, it needs only for cache.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data
1907 async def _execute_pipeline(
1908 self, connection: Connection, commands: CommandStackT, raise_on_error: bool
1909 ):
1910 # build up all commands into a single request to increase network perf
1911 all_cmds = connection.pack_commands([args for args, _ in commands])
1912 await connection.send_packed_command(all_cmds)
1914 response = []
1915 for args, options in commands:
1916 try:
1917 response.append(
1918 await self.parse_response(connection, args[0], **options)
1919 )
1920 except ResponseError as e:
1921 response.append(e)
1923 if raise_on_error:
1924 self.raise_first_error(commands, response)
1925 return response
1927 def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
1928 for i, r in enumerate(response):
1929 if isinstance(r, ResponseError):
1930 self.annotate_exception(r, i + 1, commands[i][0])
1931 raise r
1933 def annotate_exception(
1934 self, exception: Exception, number: int, command: Iterable[object]
1935 ) -> None:
1936 cmd = " ".join(map(safe_str, command))
1937 msg = (
1938 f"Command # {number} ({truncate_text(cmd)}) "
1939 f"of pipeline caused error: {exception.args}"
1940 )
1941 exception.args = (msg,) + exception.args[1:]
1943 async def parse_response(
1944 self, connection: Connection, command_name: Union[str, bytes], **options
1945 ):
1946 result = await super().parse_response(connection, command_name, **options)
1947 if command_name in self.UNWATCH_COMMANDS:
1948 self.watching = False
1949 elif command_name == "WATCH":
1950 self.watching = True
1951 return result
1953 async def load_scripts(self):
1954 # make sure all scripts that are about to be run on this pipeline exist
1955 scripts = list(self.scripts)
1956 immediate = self.immediate_execute_command
1957 shas = [s.sha for s in scripts]
1958 # we can't use the normal script_* methods because they would just
1959 # get buffered in the pipeline.
1960 exists = await immediate("SCRIPT EXISTS", *shas)
1961 if not all(exists):
1962 for s, exist in zip(scripts, exists):
1963 if not exist:
1964 s.sha = await immediate("SCRIPT LOAD", s.script)
1966 async def _disconnect_raise_on_watching(
1967 self,
1968 conn: Connection,
1969 error: Exception,
1970 failure_count: Optional[int] = None,
1971 start_time: Optional[float] = None,
1972 command_name: Optional[str] = None,
1973 ):
1974 """
1975 Close the connection, raise an exception if we were watching.
1977 The supported exceptions are already checked in the
1978 retry object so we don't need to do it here.
1980 After we disconnect the connection, it will try to reconnect and
1981 do a health check as part of the send_command logic(on connection level).
1982 """
1983 if (
1984 error
1985 and failure_count is not None
1986 and failure_count <= conn.retry.get_retries()
1987 ):
1988 await record_operation_duration(
1989 command_name=command_name,
1990 duration_seconds=time.monotonic() - start_time,
1991 server_address=getattr(conn, "host", None),
1992 server_port=getattr(conn, "port", None),
1993 db_namespace=str(conn.db),
1994 error=error,
1995 retry_attempts=failure_count,
1996 )
1997 await conn.disconnect(error=error, failure_count=failure_count)
1998 # if we were watching a variable, the watch is no longer valid
1999 # since this connection has died. raise a WatchError, which
2000 # indicates the user should retry this transaction.
2001 if self.watching:
2002 raise WatchError(
2003 f"A {type(error).__name__} occurred while watching one or more keys"
2004 )
    async def execute(self, raise_on_error: bool = True) -> List[Any]:
        """
        Execute all the commands in the current pipeline.

        Runs the buffered commands either as a MULTI/EXEC transaction or as
        a plain pipeline, with retry support and duration/error metrics
        recording. The pipeline is always reset afterwards, releasing its
        connection back to the pool.
        """
        stack = self.command_stack
        # Nothing buffered and nothing watched: nothing to do.
        if not stack and not self.watching:
            return []
        # Make sure all referenced scripts are loaded server-side first.
        if self.scripts:
            await self.load_scripts()
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
            operation_name = "MULTI"
        else:
            execute = self._execute_pipeline
            operation_name = "PIPELINE"

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        # Invoked by the retry machinery on each failure; disconnects and,
        # if we were watching keys, converts the failure into a WatchError.
        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._disconnect_raise_on_watching(
                conn, error, failure_count, start_time, operation_name
            )

        try:
            response = await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=operation_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )
            return response
        except Exception as e:
            # Count the terminal failure (after retries were exhausted or a
            # non-retryable error occurred), then propagate it unchanged.
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
        finally:
            # Always clear pipeline state and release the connection.
            await self.reset()
2069 async def discard(self):
2070 """Flushes all previously queued commands
2071 See: https://redis.io/commands/DISCARD
2072 """
2073 await self.execute_command("DISCARD")
2075 async def watch(self, *names: KeyT):
2076 """Watches the values at keys ``names``"""
2077 if self.explicit_transaction:
2078 raise RedisError("Cannot issue a WATCH after a MULTI")
2079 return await self.execute_command("WATCH", *names)
2081 async def unwatch(self):
2082 """Unwatches all previously specified keys"""
2083 return self.watching and await self.execute_command("UNWATCH") or True