Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 22%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import copy
3import inspect
4import re
5import time
6import warnings
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 AsyncIterator,
11 Awaitable,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Mapping,
17 MutableMapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26 cast,
27)
29from redis._parsers.helpers import (
30 _RedisCallbacks,
31 _RedisCallbacksRESP2,
32 _RedisCallbacksRESP3,
33 bool_ok,
34)
35from redis.asyncio.connection import (
36 Connection,
37 ConnectionPool,
38 SSLConnection,
39 UnixDomainSocketConnection,
40)
41from redis.asyncio.lock import Lock
42from redis.asyncio.observability.recorder import (
43 record_error_count,
44 record_operation_duration,
45 record_pubsub_message,
46)
47from redis.asyncio.retry import Retry
48from redis.backoff import ExponentialWithJitterBackoff
49from redis.client import (
50 EMPTY_RESPONSE,
51 NEVER_DECODE,
52 AbstractRedis,
53 CaseInsensitiveDict,
54)
55from redis.commands import (
56 AsyncCoreCommands,
57 AsyncRedisModuleCommands,
58 AsyncSentinelCommands,
59 list_or_args,
60)
61from redis.credentials import CredentialProvider
62from redis.driver_info import DriverInfo, resolve_driver_info
63from redis.event import (
64 AfterPooledConnectionsInstantiationEvent,
65 AfterPubSubConnectionInstantiationEvent,
66 AfterSingleConnectionInstantiationEvent,
67 ClientType,
68 EventDispatcher,
69)
70from redis.exceptions import (
71 ConnectionError,
72 ExecAbortError,
73 PubSubError,
74 RedisError,
75 ResponseError,
76 WatchError,
77)
78from redis.observability.attributes import PubSubDirection
79from redis.typing import ChannelT, EncodableT, KeyT
80from redis.utils import (
81 SSL_AVAILABLE,
82 _set_info_logger,
83 deprecated_args,
84 deprecated_function,
85 safe_str,
86 str_if_bytes,
87 truncate_text,
88)
# The ssl enum types are only needed for annotations. Static analysis with
# SSL support sees the real classes; at runtime the names are bound to None
# so the annotations below still resolve. The import branch is kept first so
# type checkers pick up the real classes as the primary definition.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = VerifyMode = VerifyFlags = None
if TYPE_CHECKING:
    # Imported for annotations only.
    from redis.commands.core import Script

# Signature of an async pub/sub message handler: receives the message dict.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]

# Generic type variables used throughout the async client.
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback.

    Any callable accepting the raw ``response`` plus arbitrary keyword
    options satisfies this protocol.
    """

    def __call__(self, response: Any, **kwargs): ...
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of an awaitable response callback.

    Same call shape as the synchronous variant, but ``__call__`` is a
    coroutine function.
    """

    async def __call__(self, response: Any, **kwargs): ...
# A response callback may be either a plain callable or a coroutine function;
# Redis.parse_response awaits the result when it is awaitable.
ResponseCallbackT = Union[
    ResponseCallbackProtocol,
    AsyncResponseCallbackProtocol,
]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls: Type["Redis"],
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ) -> "Redis":
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        The pool created here is owned by the returned client: unless
        ``auto_close_connection_pool`` (deprecated) says otherwise, it is
        disconnected when the client is closed.
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        # The constructor treats a supplied pool as externally managed; this
        # pool was created above, so override that default.
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional["DriverInfo"] = None,
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object(in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used.
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []

            # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
            computed_driver_info = resolve_driver_info(
                driver_info, lib_name, lib_version
            )

            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # Copy so the shared default Retry instance is never mutated
                # through a pool.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "driver_info": computed_driver_info,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                            "ssl_password": ssl_password,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False

        # Both branches notify listeners the same way, so dispatch once here.
        self._event_dispatcher.dispatch(
            AfterPooledConnectionsInstantiationEvent(
                [connection_pool], ClientType.ASYNC, credential_provider
            )
        )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        # ``await Redis(...)`` initializes the client and yields it.
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """
        For single-connection clients, acquire the dedicated connection from
        the pool (once) and dispatch AfterSingleConnectionInstantiationEvent.
        Returns the client itself.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        """Return the ``retry`` entry of the pool's connection kwargs, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        """Store ``retry`` in the pool's connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the attribute to create
        func - The function, being added to this client instance.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load the module's functions into this namespace:

            from redis import Redis
            from foomodule import F
            r = Redis()
            r.load_external_module("foo", F)
            r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        On WatchError the whole sequence is retried (after ``watch_delay``
        seconds when positive) until it succeeds.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        """Return a Monitor bound to this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """Return a new single-connection client sharing this client's pool."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop: nothing to report to.
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(
        self,
        conn: Connection,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).

        The failed attempt's duration is recorded only while retries remain
        and a ``start_time`` is known; without a start time no duration can be
        computed, but the connection is still disconnected.
        """
        if (
            error
            and failure_count is not None
            and start_time is not None
            and failure_count <= conn.retry.get_retries()
        ):
            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
                error=error,
                retry_attempts=failure_count,
            )

        await conn.disconnect(error=error, failure_count=failure_count)

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._close_connection(
                conn, error, failure_count, start_time, command_name
            )

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            result = await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                failure_callback,
                with_failure_count=True,
            )

            await record_operation_duration(
                command_name=command_name,
                duration_seconds=time.monotonic() - start_time,
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                db_namespace=str(conn.db),
            )
            return result
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            # Connections borrowed for this single call go back to the pool;
            # a dedicated single-client connection is kept.
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response
# ``StrictRedis`` is kept as a second name bound to the very same class, so
# code importing it keeps working.
StrictRedis = Redis
class MonitorCommandInfo(TypedDict):
    """One parsed line of ``MONITOR`` output, as built by Monitor.next_command."""

    # Server timestamp of the command.
    time: float
    # Database index the command ran against.
    db: int
    # "lua", "unix", or the host part of a TCP client address.
    client_address: str
    # Port (TCP), socket suffix (unix), or "" for lua.
    client_port: str
    # One of "lua", "unix", "tcp".
    client_type: str
    # The command and its arguments, space-joined.
    command: str
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use it as an async context manager: entering takes a connection from the
    pool and issues MONITOR; exiting disconnects that connection and returns
    it to the pool.
    """

    # "<db> <client-info>] <command>" part of a MONITOR line.
    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    # Each double-quoted argument, honouring backslash-escaped quotes.
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Take a connection from the pool if this Monitor does not hold one."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def __aenter__(self):
        """
        Acquire a connection and switch it into MONITOR mode.

        Raises:
            RedisError: if the server does not acknowledge MONITOR with OK.
        """
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so hand the
            # connection back here instead of leaking it.
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        """Disconnect the MONITOR connection and return it to the pool."""
        await self._release_connection()

    async def _release_connection(self) -> None:
        """Disconnect and release the held connection (if any), then forget it."""
        conn = self.connection
        if conn is None:
            return
        # Forget the connection first so a later connect()/__aenter__ never
        # reuses a connection that has already been given back to the pool.
        self.connection = None
        try:
            await conn.disconnect()
        finally:
            await self.connection_pool.release(conn)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.

    The instance holds a single connection, checked out from the pool the
    first time ``connect`` runs. It remembers its channel and pattern
    subscriptions so that ``on_connect`` can restore them after a reconnect.
    """

    # Message types that carry a published payload; a registered handler for
    # the channel/pattern is invoked for these instead of returning them.
    PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
    # Server confirmations that complete a pending (p/s)unsubscribe.
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
    # Argument sent with health-check PINGs so their replies can be recognised.
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        """
        Create a PubSub bound to ``connection_pool``; no connection is taken yet.

        Args:
            connection_pool: pool the dedicated pub/sub connection comes from.
            shard_hint: stored as-is on the instance.
            ignore_subscribe_messages: if true, ``handle_message`` drops
                subscribe/unsubscribe confirmations on every call.
            encoder: used to normalise channel/pattern names; defaults to
                ``connection_pool.get_encoder()``.
            push_handler_func: if given, ``connect`` installs it on the
                connection parser as the pub/sub push handler; if omitted,
                ``_set_info_logger()`` is called instead.
            event_dispatcher: receives an event after each ``connect``; a new
                ``EventDispatcher`` is created when omitted.
        """
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to lookup channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        # Replies equal to either entry are treated as health-check PONGs and
        # dropped by parse_response; the encoding must match what the
        # connection returns.
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        # name -> handler (None when subscribed without a callback)
        self.channels = {}
        # names sent in an UNSUBSCRIBE whose confirmation hasn't arrived yet
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        # Held by aclose() and passed along with the post-connect event.
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        """Support ``async with``; the PubSub itself is the context value."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the PubSub when leaving ``async with``; exceptions propagate."""
        await self.aclose()

    def __del__(self):
        """Deregister ``on_connect`` from the connection, if one was taken."""
        # __del__ also runs for objects whose __init__ never completed (for
        # example when the constructor is called with the wrong arguments),
        # and then ``connection`` was never assigned. Read it with getattr
        # so the finalizer does not raise AttributeError in that case; this
        # mirrors the hasattr() guard in aclose().
        connection = getattr(self, "connection", None)
        if connection:
            connection.deregister_connect_callback(self.on_connect)

    async def aclose(self):
        """
        Disconnect, return the connection to the pool and forget every
        subscription (including pending unsubscribes).
        """
        # In case a connection property does not yet exist
        # (due to a crash earlier in the constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    async def on_connect(self, connection: Connection):
        """Re-subscribe to any channels and patterns previously subscribed to"""
        # NOTE: for python3, we can't pass bytestrings as keyword arguments
        # so we need to decode channel/pattern names back to unicode strings
        # before passing them to [p]subscribe.
        #
        # However, channels subscribed without a callback (positional args) may
        # have binary names that are not valid in the current encoding (e.g.
        # arbitrary bytes that are not valid UTF-8). These channels are stored
        # with a ``None`` handler. We re-subscribe them as positional args so
        # that no decoding is required.
        # A fresh connection has no in-flight unsubscribes.
        self.pending_unsubscribe_channels.clear()
        self.pending_unsubscribe_patterns.clear()
        if self.channels:
            channels_with_handlers = {}
            channels_without_handlers = []
            for k, v in self.channels.items():
                if v is not None:
                    channels_with_handlers[self.encoder.decode(k, force=True)] = v
                else:
                    channels_without_handlers.append(k)
            if channels_with_handlers or channels_without_handlers:
                await self.subscribe(
                    *channels_without_handlers, **channels_with_handlers
                )
        if self.patterns:
            patterns_with_handlers = {}
            patterns_without_handlers = []
            for k, v in self.patterns.items():
                if v is not None:
                    patterns_with_handlers[self.encoder.decode(k, force=True)] = v
                else:
                    patterns_without_handlers.append(k)
            if patterns_with_handlers or patterns_without_handlers:
                await self.psubscribe(
                    *patterns_without_handlers, **patterns_with_handlers
                )

    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        return bool(self.channels or self.patterns)

    async def execute_command(self, *args: EncodableT):
        """Execute a publish/subscribe command"""

        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        await self.connect()
        connection = self.connection
        # Health checks are only requested while nothing is subscribed yet.
        kwargs = {"check_health": not self.subscribed}
        await self._execute(connection, connection.send_command, *args, **kwargs)

    async def connect(self):
        """
        Ensure that the PubSub is connected

        The first call takes a connection from the pool and registers
        ``on_connect`` on it; later calls ask the existing connection to
        connect. Every call re-installs the push handler (if any) and
        dispatches ``AfterPubSubConnectionInstantiationEvent``.
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )

    async def _reconnect(
        self,
        conn,
        error: Optional[BaseException] = None,
        failure_count: Optional[int] = None,
        start_time: Optional[float] = None,
        command_name: Optional[str] = None,
    ):
        """
        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        In this error handler we are trying to reconnect to the server.
        While retries remain, the failed attempt's duration is recorded
        first (only when ``command_name`` is known).
        """
        if (
            error
            and failure_count is not None
            and failure_count <= conn.retry.get_retries()
        ):
            if command_name:
                await record_operation_duration(
                    command_name=command_name,
                    duration_seconds=time.monotonic() - start_time,
                    server_address=getattr(conn, "host", None),
                    server_port=getattr(conn, "port", None),
                    db_namespace=str(conn.db),
                    error=error,
                    retry_attempts=failure_count,
                )
        await conn.disconnect(error=error, failure_count=failure_count)
        await conn.connect()

    async def _execute(self, conn, command, *args, **kwargs):
        """
        Connect manually upon disconnection. If the Redis server is down,
        this will fail and raise a ConnectionError as desired.
        After reconnection, the ``on_connect`` callback should have been
        called by the connection to resubscribe us to any channels and
        patterns we were previously listening to.

        ``command(*args, **kwargs)`` is run through ``conn.retry`` with
        ``_reconnect`` as the failure callback. Its result is returned after
        the operation duration is recorded; on failure the error count is
        recorded and the exception re-raised.
        """
        if not len(args) == 0:
            command_name = args[0]
        else:
            command_name = None

        # Start timing for observability
        start_time = time.monotonic()
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self._reconnect(conn, error, failure_count, start_time, command_name)

        try:
            response = await conn.retry.call_with_retry(
                lambda: command(*args, **kwargs),
                failure_callback,
                with_failure_count=True,
            )

            if command_name:
                await record_operation_duration(
                    command_name=command_name,
                    duration_seconds=time.monotonic() - start_time,
                    server_address=getattr(conn, "host", None),
                    server_port=getattr(conn, "port", None),
                    db_namespace=str(conn.db),
                )

            return response
        except Exception as e:
            await record_error_count(
                server_address=getattr(conn, "host", None),
                server_port=getattr(conn, "port", None),
                network_peer_address=getattr(conn, "host", None),
                network_peer_port=getattr(conn, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """Parse the response from a publish/subscribe command

        Reads with ``timeout=None`` when ``block`` is true, otherwise with
        ``timeout``. Returns ``None`` for health-check replies.

        Raises:
            RuntimeError: if no connection has been set up yet.
        """
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as user might not expect it
            return None
        return response

    async def check_health(self):
        """
        Send a health-check PING when the connection's interval has elapsed.

        The PING is sent only if ``health_check_interval`` is set and the
        running loop's clock is past ``next_health_check``. Its reply is not
        read here.

        Raises:
            RuntimeError: if no connection has been set up yet.
        """
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        if (
            conn.health_check_interval
            and asyncio.get_running_loop().time() > conn.next_health_check
        ):
            await conn.send_command(
                "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
            )

    def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
        """
        normalize channel/pattern names to be either bytes or strings
        based on whether responses are automatically decoded. this saves us
        from coercing the value for each message coming in.
        """
        encode = self.encoder.encode
        decode = self.encoder.decode
        return {decode(encode(k)): v for k, v in data.items()}  # type: ignore[return-value] # noqa: E501

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else args
        new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_patterns.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
        # update the patterns dict AFTER we send the command. we don't want to
        # subscribe twice to these patterns, once for the command and again
        # for the reconnection.
        new_patterns = self._normalize_keys(new_patterns)
        self.patterns.update(new_patterns)
        self.pending_unsubscribe_patterns.difference_update(new_patterns)
        return ret_val

    def punsubscribe(self, *args: ChannelT) -> Awaitable:
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.

        The patterns are only marked pending here; they are dropped from
        ``self.patterns`` when the server's confirmation is handled.
        """
        patterns: Iterable[ChannelT]
        if args:
            parsed_args = list_or_args((args[0],), args[1:])
            patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
        else:
            parsed_args = []
            patterns = self.patterns
        self.pending_unsubscribe_patterns.update(patterns)
        return self.execute_command("PUNSUBSCRIBE", *parsed_args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else ()
        new_channels = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_channels.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
        # update the channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_channels = self._normalize_keys(new_channels)
        self.channels.update(new_channels)
        self.pending_unsubscribe_channels.difference_update(new_channels)
        return ret_val

    def unsubscribe(self, *args) -> Awaitable:
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels

        The channels are only marked pending here; they are dropped from
        ``self.channels`` when the server's confirmation is handled.
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            parsed_args = []
            channels = self.channels
        self.pending_unsubscribe_channels.update(channels)
        return self.execute_command("UNSUBSCRIBE", *parsed_args)

    async def listen(self) -> AsyncIterator:
        """Listen for messages on channels this client has been subscribed to"""
        while self.subscribed:
            response = await self.handle_message(await self.parse_response(block=True))
            if response is not None:
                yield response

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        response = await self.parse_response(block=(timeout is None), timeout=timeout)
        if response:
            return await self.handle_message(response, ignore_subscribe_messages)
        return None

    def ping(self, message=None) -> Awaitable[bool]:
        """
        Ping the Redis server to test connectivity.

        Sends a PING command to the Redis server and returns True if the server
        responds with "PONG".
        """
        args = ["PING", message] if message is not None else ["PING"]
        return self.execute_command(*args)

    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Returns a dict with ``type``, ``pattern``, ``channel`` and ``data``
        keys, or ``None`` when the message was handled or ignored.
        """
        if response is None:
            return None
        # A bare bytes reply is treated as a PING reply ("PONG" means no payload).
        if isinstance(response, bytes):
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        if message_type in ["message", "pmessage"]:
            channel = str_if_bytes(message["channel"])
            await record_pubsub_message(
                direction=PubSubDirection.RECEIVE,
                channel=channel,
            )

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
        pubsub=None,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

            >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

            >>> task.cancel()
            >>> await task

        Raises:
            PubSubError: if any subscribed channel or pattern has no handler.
        """
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                if pubsub is None:
                    await self.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
                else:
                    await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
            except asyncio.CancelledError:
                # Cancellation is the documented way to stop; never hand it
                # to the exception handler.
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Structural type of a plain (non-async) exception handler.

    Matches any callable taking the raised exception and the ``PubSub``
    instance it came from.
    """

    def __call__(self, e: BaseException, pubsub: PubSub):
        ...
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Structural type of a coroutine exception handler.

    Matches any ``async`` callable taking the raised exception and the
    ``PubSub`` instance it came from.
    """

    async def __call__(self, e: BaseException, pubsub: PubSub):
        ...
# Accepted type of ``PubSub.run(exception_handler=...)``: either a plain
# callable or a coroutine function with the same parameters.
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler,
    AsyncPubsubWorkerExceptionHandler,
]
# One buffered pipeline command: its positional arguments (command name
# first) paired with the keyword options given alongside them.
CommandT = Tuple[
    Tuple[Union[str, bytes], ...],
    Mapping[str, Any],
]
# The ordered list of commands a Pipeline holds until execute().
CommandStackT = List[CommandT]
1476class Pipeline(Redis): # lgtm [py/init-calls-subclass]
1477 """
1478 Pipelines provide a way to transmit multiple commands to the Redis server
1479 in one transmission. This is convenient for batch processing, such as
1480 saving all the values in a list to Redis.
1482 All commands executed within a pipeline(when running in transactional mode,
1483 which is the default behavior) are wrapped with MULTI and EXEC
1484 calls. This guarantees all commands executed in the pipeline will be
1485 executed atomically.
1487 Any command raising an exception does *not* halt the execution of
1488 subsequent commands in the pipeline. Instead, the exception is caught
1489 and its instance is placed into the response list returned by execute().
1490 Code iterating over the response list should be able to deal with an
1491 instance of an exception as a potential value. In general, these will be
1492 ResponseError exceptions, such as those raised when issuing a command
1493 on a key of a different datatype.
1494 """
1496 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
def __init__(
    self,
    connection_pool: ConnectionPool,
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
    transaction: bool,
    shard_hint: Optional[str],
):
    """Prepare an empty pipeline that draws connections from ``connection_pool``.

    Only the attributes the pipeline itself needs are set here;
    ``super().__init__()`` is deliberately not called.
    """
    # Connection source, and the connection once one has been checked out.
    self.connection_pool = connection_pool
    self.connection = None
    self.response_callbacks = response_callbacks
    # True when the whole batch should be wrapped in MULTI/EXEC.
    self.is_transaction = transaction
    self.shard_hint = shard_hint
    # Whether a WATCH is currently active on self.connection.
    self.watching = False
    # Commands queued until execute() is called.
    self.command_stack: CommandStackT = list()
    self.scripts: Set[Script] = set()
    # Set by multi() to resume buffering after WATCH.
    self.explicit_transaction = False
async def __aenter__(self: _RedisT) -> _RedisT:
    """Enter ``async with``; the pipeline itself is the context value."""
    pipeline = self
    return pipeline
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave ``async with`` by resetting the pipeline, whatever the outcome.

    ``None`` is returned, so any exception from the body propagates.
    """
    await self.reset()
    return None
def __await__(self):
    """Make the pipeline awaitable; awaiting it yields ``_async_self()``'s result.

    ``execute_command`` may return the pipeline itself (buffered command),
    so this lets callers ``await`` its result uniformly.
    """
    coroutine = self._async_self()
    return coroutine.__await__()

# Text for the "unclosed client" warning. Presumably consumed by a
# finalizer inherited from the base client -- TODO confirm.
_DEL_MESSAGE = "Unclosed Pipeline client"
def __len__(self):
    """Return how many commands are currently queued."""
    queued = self.command_stack
    return len(queued)
def __bool__(self):
    """Always truthy, even with an empty command queue."""
    # Without this override, __len__ would make an empty pipeline falsy.
    return True
async def _async_self(self):
    """Coroutine that resolves to the pipeline itself; backs ``__await__``."""
    pipeline = self
    return pipeline
async def reset(self):
    """Clear queued state and give the connection back to the pool.

    If a WATCH is active, UNWATCH is sent first. Should that fail with
    ``ConnectionError``, the connection is disconnected instead, which also
    drops the WATCH.
    """
    self.command_stack, self.scripts = [], set()
    must_unwatch = self.watching and self.connection
    if must_unwatch:
        # Talk to the connection directly: our unwatch() and
        # immediate_execute_command() paths can themselves call reset().
        try:
            await self.connection.send_command("UNWATCH")
            await self.connection.read_response()
        except ConnectionError:
            if self.connection:
                await self.connection.disconnect()
    self.watching = self.explicit_transaction = False
    # Nothing is WATCHed any more, so the connection can go back to the pool.
    if not self.connection:
        return
    await self.connection_pool.release(self.connection)
    self.connection = None
async def aclose(self) -> None:
    """Standard async cleanup name; simply delegates to :meth:`reset`."""
    await self.reset()
def multi(self):
    """
    Open an explicit transactional block, normally after WATCH commands.

    Subsequent commands are queued and sent wrapped in MULTI/EXEC by
    ``execute``.

    Raises:
        RedisError: if MULTI was already issued, or if commands were
            already queued before this call.
    """
    problem = None
    if self.explicit_transaction:
        problem = "Cannot issue nested calls to MULTI"
    elif self.command_stack:
        problem = "Commands without an initial WATCH have already been issued"
    if problem is not None:
        raise RedisError(problem)
    self.explicit_transaction = True
def execute_command(
    self, *args, **kwargs
) -> Union["Pipeline", Awaitable["Pipeline"]]:
    """Send a command now, or queue it for ``execute``.

    While a WATCH is active (or this command is WATCH itself) and no
    explicit MULTI has been issued, the command runs immediately and an
    awaitable is returned. Otherwise it is queued and the pipeline returned.
    """
    run_now = (self.watching or args[0] == "WATCH") and not self.explicit_transaction
    if not run_now:
        return self.pipeline_execute_command(*args, **kwargs)
    return self.immediate_execute_command(*args, **kwargs)
async def _disconnect_reset_raise_on_watching(
    self,
    conn: Connection,
    error: Exception,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
) -> None:
    """
    Retry failure callback: disconnect, and abort if a WATCH was active.

    Exception filtering already happened in the retry object, so every
    error arriving here is one it handles. While attempts remain
    (``failure_count`` not above ``conn.retry.get_retries()``), the failed
    attempt's duration is recorded. The connection is then disconnected.
    According to the original design, the next send reconnects and runs
    the connection-level health check.

    Raises:
        WatchError: if keys were being watched. The pipeline is reset
            first, since the WATCH died with the connection and the caller
            must retry the whole transaction.
    """
    attempts_remain = (
        error
        and failure_count is not None
        and failure_count <= conn.retry.get_retries()
    )
    if attempts_remain:
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )
    await conn.disconnect(error=error, failure_count=failure_count)
    if not self.watching:
        return
    await self.reset()
    raise WatchError(
        f"A {type(error).__name__} occurred while watching one or more keys"
    )
async def immediate_execute_command(self, *args, **options):
    """
    Run one command right away on the pipeline's own connection.

    Used for WATCH and for reads issued after it but before MULTI. Retries
    go through ``conn.retry``. Its failure callback aborts with
    ``WatchError`` instead of retrying if a WATCH is active.

    Returns:
        The parsed reply, after recording the operation duration.

    Raises:
        Whatever the retry loop raises; the error count is recorded first.
    """
    command_name = args[0]
    conn = self.connection
    if not conn:
        # First command on this pipeline: check out a connection and pin it.
        conn = self.connection = await self.connection_pool.get_connection()

    started = time.monotonic()
    # Updated by the failure callback so errors report real retry counts.
    retries_seen = 0

    def on_failure(error, failure_count):
        nonlocal retries_seen
        retries_seen = failure_count
        return self._disconnect_reset_raise_on_watching(
            conn, error, failure_count, started, command_name
        )

    def attempt():
        return self._send_command_parse_response(
            conn, command_name, *args, **options
        )

    try:
        reply = await conn.retry.call_with_retry(
            attempt,
            on_failure,
            with_failure_count=True,
        )
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - started,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=retries_seen,
            is_internal=False,
        )
        raise
    return reply
def pipeline_execute_command(self, *args, **options):
    """
    Queue a command to be sent on the next ``execute()`` call.

    The pipeline is returned so that calls can be chained, e.g.::

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

    Nothing reaches the server until ``pipe.execute()`` runs all queued
    commands.
    """
    entry = (args, options)
    self.command_stack.append(entry)
    return self
async def _execute_transaction(  # noqa: C901
    self, connection: Connection, commands: CommandStackT, raise_on_error: bool
):
    """
    Send ``commands`` wrapped in MULTI/EXEC in one write and collect results.

    Commands whose options contain ``EMPTY_RESPONSE`` are not sent; the
    option's value is placed in the results at that command's position.
    ResponseErrors raised while commands are queued are annotated and
    likewise placed in the results, so every queued reply is drained from
    the socket.

    Returns:
        One entry per command, with response callbacks applied to the
        non-exception entries.

    Raises:
        The first queued error (chained from ``ExecAbortError``) if EXEC was
        aborted; ``WatchError`` if EXEC returned nil; ``ResponseError`` if
        the reply count does not match (the connection is disconnected
        first); the first ``ResponseError`` when ``raise_on_error`` is set.
    """
    pre: CommandT = (("MULTI",), {})
    post: CommandT = (("EXEC",), {})
    cmds = (pre, *commands, post)
    # Single round trip: skip commands that carry a canned EMPTY_RESPONSE.
    all_cmds = connection.pack_commands(
        args for args, options in cmds if EMPTY_RESPONSE not in options
    )
    await connection.send_packed_command(all_cmds)
    # (position, exception-or-placeholder) pairs spliced into the EXEC result.
    errors = []

    # parse off the response for MULTI
    # NOTE: we need to handle ResponseErrors here and continue
    # so that we read all the additional command messages from
    # the socket
    try:
        await self.parse_response(connection, "_")
    except ResponseError as err:
        errors.append((0, err))

    # and all the other commands
    for i, command in enumerate(commands):
        if EMPTY_RESPONSE in command[1]:
            # Nothing was sent for this command, so there is no reply to read.
            errors.append((i, command[1][EMPTY_RESPONSE]))
        else:
            try:
                await self.parse_response(connection, "_")
            except ResponseError as err:
                self.annotate_exception(err, i + 1, command[0])
                errors.append((i, err))

    # parse the EXEC.
    try:
        response = await self.parse_response(connection, "_")
    except ExecAbortError as err:
        if errors:
            raise errors[0][1] from err
        raise

    # EXEC clears any watched keys
    self.watching = False

    # A nil EXEC reply means a watched key changed and nothing ran.
    if response is None:
        raise WatchError("Watched variable changed.") from None

    # put any parse errors into the response
    for i, e in errors:
        response.insert(i, e)

    if len(response) != len(commands):
        if self.connection:
            await self.connection.disconnect()
        raise ResponseError(
            "Wrong number of response items from pipeline execution"
        ) from None

    # find any errors in the response and raise if necessary
    if raise_on_error:
        self.raise_first_error(commands, response)

    # We have to run response callbacks manually
    data = []
    for r, cmd in zip(response, commands):
        if not isinstance(r, Exception):
            args, options = cmd
            command_name = args[0]

            # Remove keys entry, it needs only for cache.
            options.pop("keys", None)

            if command_name in self.response_callbacks:
                r = self.response_callbacks[command_name](r, **options)
                # Callbacks may be sync or async.
                if inspect.isawaitable(r):
                    r = await r
        data.append(r)
    return data
async def _execute_pipeline(
    self, connection: Connection, commands: CommandStackT, raise_on_error: bool
):
    """Send every queued command in one write and read one reply for each.

    A ``ResponseError`` for a command is stored in that command's result
    slot rather than stopping the read loop, so all replies are drained.
    When ``raise_on_error`` is set, the first such error is raised after
    the loop.
    """
    # build up all commands into a single request to increase network perf
    packed = connection.pack_commands([cmd_args for cmd_args, _ in commands])
    await connection.send_packed_command(packed)

    results = []
    for cmd_args, cmd_options in commands:
        try:
            reply = await self.parse_response(connection, cmd_args[0], **cmd_options)
        except ResponseError as exc:
            reply = exc
        results.append(reply)

    if raise_on_error:
        self.raise_first_error(commands, results)
    return results
def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
    """Raise the earliest ``ResponseError`` found in ``response``.

    The error is first annotated with the 1-based position and arguments
    of the command that produced it. Returns ``None`` when there is none.
    """
    first = next(
        (
            (position, item)
            for position, item in enumerate(response)
            if isinstance(item, ResponseError)
        ),
        None,
    )
    if first is None:
        return
    position, item = first
    self.annotate_exception(item, position + 1, commands[position][0])
    raise item
def annotate_exception(
    self, exception: Exception, number: int, command: Iterable[object]
) -> None:
    """
    Prefix ``exception``'s message with the pipeline command that caused it.

    Args:
        exception: error to annotate in place; its first ``args`` entry is
            replaced by the new message and the remaining entries are kept.
        number: 1-based position of the failing command in the pipeline.
        command: the command's arguments (name first); each is rendered with
            ``safe_str`` and the joined text is shortened by ``truncate_text``.
    """
    cmd = " ".join(map(safe_str, command))
    # Use the original message text (args[0]) rather than the repr of the
    # whole args tuple. Otherwise users see "caused error: ('ERR ...',)".
    # args[1:] is kept below, so only args[0] belongs in the message.
    # Exceptions created without arguments have no message to append.
    detail = exception.args[0] if exception.args else ""
    msg = (
        f"Command # {number} ({truncate_text(cmd)}) "
        f"of pipeline caused error: {detail}"
    )
    exception.args = (msg,) + exception.args[1:]
async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """Parse a reply through the base client, then update the WATCH flag.

    Commands in ``UNWATCH_COMMANDS`` clear ``watching``; WATCH sets it.
    """
    parsed = await super().parse_response(connection, command_name, **options)
    if command_name in self.UNWATCH_COMMANDS:
        # These commands end any WATCH held on the connection.
        self.watching = False
    elif command_name == "WATCH":
        self.watching = True
    return parsed
async def load_scripts(self):
    """Make sure every script registered on this pipeline exists on the server.

    SCRIPT EXISTS is run immediately for all script SHAs. Going through
    ``immediate_execute_command`` matters here, because the normal script_*
    helpers would just be queued in the pipeline. Each script reported
    missing is then sent with SCRIPT LOAD, and the returned SHA is stored
    back on it.
    """
    pending = list(self.scripts)
    run_now = self.immediate_execute_command
    digests = [script.sha for script in pending]
    found = await run_now("SCRIPT EXISTS", *digests)
    if all(found):
        return
    for script, present in zip(pending, found):
        if present:
            continue
        script.sha = await run_now("SCRIPT LOAD", script.script)
async def _disconnect_raise_on_watching(
    self,
    conn: Connection,
    error: Exception,
    failure_count: Optional[int] = None,
    start_time: Optional[float] = None,
    command_name: Optional[str] = None,
):
    """
    Retry failure callback used by ``execute``: disconnect, then abort if a
    WATCH was active.

    The retry object has already filtered the exception types. While
    attempts remain (``failure_count`` not above
    ``conn.retry.get_retries()``), the failed attempt's duration is
    recorded. The connection is then disconnected. According to the
    original design, the next send reconnects and runs the connection-level
    health check.

    Raises:
        WatchError: when ``self.watching`` is set, because the WATCH died
            with the connection. The pipeline is not reset here.
    """
    attempts_remain = (
        error
        and failure_count is not None
        and failure_count <= conn.retry.get_retries()
    )
    if attempts_remain:
        await record_operation_duration(
            command_name=command_name,
            duration_seconds=time.monotonic() - start_time,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
            error=error,
            retry_attempts=failure_count,
        )
    await conn.disconnect(error=error, failure_count=failure_count)
    if not self.watching:
        return
    raise WatchError(
        f"A {type(error).__name__} occurred while watching one or more keys"
    )
async def execute(self, raise_on_error: bool = True) -> List[Any]:
    """Execute all the commands in the current pipeline

    Returns ``[]`` straight away when nothing is queued and no WATCH is
    active. Otherwise it loads queued scripts, then runs the batch as a
    MULTI/EXEC transaction or a plain pipeline through ``conn.retry``. The
    operation duration (or error count) is recorded. The pipeline is always
    reset afterwards, which releases the connection.
    """
    stack = self.command_stack
    if not (stack or self.watching):
        return []
    if self.scripts:
        await self.load_scripts()

    transactional = self.is_transaction or self.explicit_transaction
    runner, operation_name = (
        (self._execute_transaction, "MULTI")
        if transactional
        else (self._execute_pipeline, "PIPELINE")
    )

    conn = self.connection
    if not conn:
        # Pin it to self.connection so reset() hands it back to the pool.
        conn = self.connection = await self.connection_pool.get_connection()
    conn = cast(Connection, conn)

    started = time.monotonic()
    # Updated by the failure callback so errors report real retry counts.
    retries_seen = 0

    def on_failure(error, failure_count):
        nonlocal retries_seen
        retries_seen = failure_count
        return self._disconnect_raise_on_watching(
            conn, error, failure_count, started, operation_name
        )

    def attempt():
        return runner(conn, stack, raise_on_error)

    try:
        results = await conn.retry.call_with_retry(
            attempt,
            on_failure,
            with_failure_count=True,
        )
        await record_operation_duration(
            command_name=operation_name,
            duration_seconds=time.monotonic() - started,
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            db_namespace=str(conn.db),
        )
        return results
    except Exception as exc:
        await record_error_count(
            server_address=getattr(conn, "host", None),
            server_port=getattr(conn, "port", None),
            network_peer_address=getattr(conn, "host", None),
            network_peer_port=getattr(conn, "port", None),
            error_type=exc,
            retry_attempts=retries_seen,
            is_internal=False,
        )
        raise
    finally:
        await self.reset()
async def discard(self):
    """Drop every previously queued command by issuing ``DISCARD``.

    The command is routed through ``execute_command``; its reply is not
    passed back to the caller.

    See: https://redis.io/commands/DISCARD
    """
    await self.execute_command("DISCARD")
    return None
async def watch(self, *names: KeyT):
    """Issue ``WATCH`` for every key in ``names`` and return the reply.

    :raises RedisError: when a MULTI has already been issued on this
        pipeline (``explicit_transaction`` is set).
    """
    if not self.explicit_transaction:
        return await self.execute_command("WATCH", *names)
    raise RedisError("Cannot issue a WATCH after a MULTI")
async def unwatch(self):
    """Unwatches all previously specified keys.

    Sends ``UNWATCH`` only when keys are currently being watched.

    :returns: ``True`` when nothing was being watched. Otherwise the reply
        to ``UNWATCH``, falling back to ``True`` if that reply is falsy
        (the same results as the previous ``a and b or c`` expression).
    """
    if not self.watching:
        return True
    reply = await self.execute_command("UNWATCH")
    # Keep the historical contract: a falsy reply is reported as True.
    return reply or True