Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 23%
import asyncio
import copy
import inspect
import re
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Protocol,
    Set,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

from redis._parsers.helpers import (
    _RedisCallbacks,
    _RedisCallbacksRESP2,
    _RedisCallbacksRESP3,
    bool_ok,
)
from redis.asyncio.connection import (
    Connection,
    ConnectionPool,
    SSLConnection,
    UnixDomainSocketConnection,
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
    EMPTY_RESPONSE,
    NEVER_DECODE,
    AbstractRedis,
    CaseInsensitiveDict,
)
from redis.commands import (
    AsyncCoreCommands,
    AsyncRedisModuleCommands,
    AsyncSentinelCommands,
    list_or_args,
)
from redis.credentials import CredentialProvider
from redis.driver_info import DriverInfo, resolve_driver_info
from redis.event import (
    AfterPooledConnectionsInstantiationEvent,
    AfterPubSubConnectionInstantiationEvent,
    AfterSingleConnectionInstantiationEvent,
    ClientType,
    EventDispatcher,
)
from redis.exceptions import (
    ConnectionError,
    ExecAbortError,
    PubSubError,
    RedisError,
    ResponseError,
    WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
    SSL_AVAILABLE,
    _set_info_logger,
    deprecated_args,
    deprecated_function,
    safe_str,
    str_if_bytes,
    truncate_text,
)

if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    from redis.commands.core import Script


class ResponseCallbackProtocol(Protocol):
    def __call__(self, response: Any, **kwargs): ...


class AsyncResponseCallbackProtocol(Protocol):
    async def __call__(self, response: Any, **kwargs): ...


ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]


class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent to and received from the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls: Type["Redis"],
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ) -> "Redis":
        """
        Return a Redis client object configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates an SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client
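
    # Example (editor's illustrative sketch, not part of the original source):
    # building a client from a URL and issuing a command. The URL, key, and
    # value below are assumptions chosen for demonstration.
    #
    #     import asyncio
    #     from redis.asyncio import Redis
    #
    #     async def main():
    #         r = Redis.from_url("redis://localhost:6379/0", decode_responses=True)
    #         await r.set("greeting", "hello")
    #         print(await r.get("greeting"))  # -> "hello"
    #         await r.aclose()
    #
    #     asyncio.run(main())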

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client
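
    # Example (editor's illustrative sketch; URL is an assumption): handing an
    # explicitly created pool to the client. With from_pool() the client owns
    # the pool, so aclose() also disconnects it.
    #
    #     from redis.asyncio import ConnectionPool, Redis
    #
    #     pool = ConnectionPool.from_url("redis://localhost:6379/0")
    #     client = Redis.from_pool(pool)
    #     ...
    #     await client.aclose()  # also disconnects the owned pool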

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional["DriverInfo"] = None,
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set ``retry_on_error`` to a list of the error(s) to retry on, and
        optionally set ``retry`` to a valid `Retry` object (in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object, enriched with the
        errors specified in ``retry_on_error``.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it as the ``retry`` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        ``retry_on_timeout`` is deprecated - please include the TimeoutError
        either in the Retry object or in the ``retry_on_error`` list.

        When ``connection_pool`` is provided, the retry configuration of the
        provided pool will be used.
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []

            # Handle driver_info: if provided, use it; otherwise create it
            # from lib_name/lib_version
            computed_driver_info = resolve_driver_info(
                driver_info, lib_name, lib_version
            )

            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "driver_info": computed_driver_info,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, set up the appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                            "ssl_password": ssl_password,
                        }
                    )
            # This arg is only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's keyword arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates commands named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue
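
    # Example (editor's illustrative sketch; the key name is an assumption,
    # and decode_responses=True is assumed so GET returns a str): atomically
    # incrementing a counter via transaction(). The callable receives a
    # Pipeline in WATCH (immediate) mode; transaction() retries on WatchError.
    #
    #     async def incr_visits(pipe: "Pipeline") -> None:
    #         current = await pipe.get("visits")  # immediate mode while watching
    #         pipe.multi()                        # switch to buffered mode
    #         pipe.set("visits", int(current or 0) + 1)
    #
    #     await r.transaction(incr_visits, "visits")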

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                 thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
                 Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
                 key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
                 thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
                 token is *not* stored in thread local storage, then
                 thread-1 would see the token value as "xyz" and would be
                 able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )
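
    # Example (editor's illustrative sketch; the lock name and timeouts are
    # assumptions): the returned Lock supports "async with", releasing the
    # lock automatically on exit.
    #
    #     async with r.lock("my-resource", timeout=5, blocking_timeout=3):
    #         ...  # critical section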

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())
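
    # Example (editor's illustrative sketch): nested "async with" blocks on the
    # same client are reference-counted; the pool is only closed when the last
    # context exits.
    #
    #     r = Redis()
    #     async with r:
    #         async with r:      # usage counter -> 2
    #             await r.ping()
    #         # counter -> 1, pool still open
    #     # counter -> 0, aclose() runs and the pool is disconnected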

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes the Redis client connection.

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove the "keys" entry; it is only needed for client-side caching.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response


StrictRedis = Redis


class MonitorCommandInfo(TypedDict):
    time: float
    db: int
    client_address: str
    client_port: str
    client_type: str
    command: str


class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    The next_command() method returns one command from the monitor;
    the listen() method yields commands as they arrive.
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def __aenter__(self):
        await self.connect()
        await self.connection.send_command("MONITOR")
        # check that monitor returns 'OK', but don't return it to user
        response = await self.connection.read_response()
        if not bool_ok(response):
            raise RedisError(f"MONITOR failed: {response}")
        return self

    async def __aexit__(self, *args):
        await self.connection.disconnect()
        await self.connection_pool.release(self.connection)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
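
    # Example (editor's illustrative sketch): printing every command the
    # server sees. MONITOR holds a dedicated connection for the lifetime of
    # the context manager.
    #
    #     async with client.monitor() as m:
    #         async for info in m.listen():
    #             print(info["time"], info["client_address"], info["command"])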


class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.
    """

    PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to look up channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    def __del__(self):
        if self.connection:
            self.connection.deregister_connect_callback(self.on_connect)

    async def aclose(self):
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    async def on_connect(self, connection: Connection):
        """Re-subscribe to any channels and patterns previously subscribed to"""
        # NOTE: for python3, we can't pass bytestrings as keyword arguments
        # so we need to decode channel/pattern names back to unicode strings
        # before passing them to [p]subscribe.
        self.pending_unsubscribe_channels.clear()
        self.pending_unsubscribe_patterns.clear()
        if self.channels:
            channels = {}
            for k, v in self.channels.items():
                channels[self.encoder.decode(k, force=True)] = v
            await self.subscribe(**channels)
        if self.patterns:
            patterns = {}
            for k, v in self.patterns.items():
                patterns[self.encoder.decode(k, force=True)] = v
            await self.psubscribe(**patterns)

    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        return bool(self.channels or self.patterns)

    async def execute_command(self, *args: EncodableT):
        """Execute a publish/subscribe command"""

        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        await self.connect()
        connection = self.connection
        kwargs = {"check_health": not self.subscribed}
        await self._execute(connection, connection.send_command, *args, **kwargs)

    async def connect(self):
        """
        Ensure that the PubSub is connected
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )

    async def _reconnect(self, conn):
        """
        Try to reconnect
        """
        await conn.disconnect()
        await conn.connect()

    async def _execute(self, conn, command, *args, **kwargs):
        """
        Connect manually upon disconnection. If the Redis server is down,
        this will fail and raise a ConnectionError as desired.
        After reconnection, the ``on_connect`` callback should have been
        called by the connection to resubscribe us to any channels and
        patterns we were previously listening to.
        """
        return await conn.retry.call_with_retry(
            lambda: command(*args, **kwargs),
            lambda _: self._reconnect(conn),
        )

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """Parse the response from a publish/subscribe command"""
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as the user might not expect it
            return None
        return response

    async def check_health(self):
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        if (
            conn.health_check_interval
            and asyncio.get_running_loop().time() > conn.next_health_check
        ):
            await conn.send_command(
                "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
            )

    def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
        """
        Normalize channel/pattern names to be either bytes or strings
        based on whether responses are automatically decoded. This saves us
        from coercing the value for each message coming in.
        """
        encode = self.encoder.encode
        decode = self.encoder.decode
        return {decode(encode(k)): v for k, v in data.items()}  # type: ignore[return-value] # noqa: E501

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else args
        new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_patterns.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
        # update the patterns dict AFTER we send the command. we don't want to
        # subscribe twice to these patterns, once for the command and again
        # for the reconnection.
        new_patterns = self._normalize_keys(new_patterns)
        self.patterns.update(new_patterns)
        self.pending_unsubscribe_patterns.difference_update(new_patterns)
        return ret_val

    def punsubscribe(self, *args: ChannelT) -> Awaitable:
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        patterns: Iterable[ChannelT]
        if args:
            parsed_args = list_or_args((args[0],), args[1:])
            patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
        else:
            parsed_args = []
            patterns = self.patterns
        self.pending_unsubscribe_patterns.update(patterns)
        return self.execute_command("PUNSUBSCRIBE", *parsed_args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else ()
        new_channels = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_channels.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
        # update the channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_channels = self._normalize_keys(new_channels)
        self.channels.update(new_channels)
        self.pending_unsubscribe_channels.difference_update(new_channels)
        return ret_val

    def unsubscribe(self, *args) -> Awaitable:
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels.
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            parsed_args = []
            channels = self.channels
        self.pending_unsubscribe_channels.update(channels)
        return self.execute_command("UNSUBSCRIBE", *parsed_args)

    async def listen(self) -> AsyncIterator:
        """Listen for messages on channels this client has been subscribed to"""
        while self.subscribed:
            response = await self.handle_message(await self.parse_response(block=True))
            if response is not None:
                yield response

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        response = await self.parse_response(block=(timeout is None), timeout=timeout)
        if response:
            return await self.handle_message(response, ignore_subscribe_messages)
        return None
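
    # Example (editor's illustrative sketch; the channel name is an
    # assumption): polling for messages with a timeout instead of blocking in
    # listen().
    #
    #     p = r.pubsub()
    #     await p.subscribe("news")
    #     while True:
    #         message = await p.get_message(
    #             ignore_subscribe_messages=True, timeout=1.0
    #         )
    #         if message is not None:
    #             print(message["channel"], message["data"])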

    def ping(self, message=None) -> Awaitable[bool]:
        """
        Ping the Redis server to test connectivity.

        Sends a PING command to the Redis server and returns True if the server
        responds with "PONG".
        """
        args = ["PING", message] if message is not None else ["PING"]
        return self.execute_command(*args)

    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.
        """
        if response is None:
            return None
        if isinstance(response, bytes):
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
        pubsub=None,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                if pubsub is None:
                    await self.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
                else:
                    await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=poll_timeout
                    )
            except asyncio.CancelledError:
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)
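
    # Example (editor's illustrative sketch; the channel and handler names are
    # assumptions): driving handler callbacks with run() as a background task.
    #
    #     async def on_news(message):
    #         print(message["data"])
    #
    #     p = r.pubsub()
    #     await p.subscribe(news=on_news)  # the handler consumes each message
    #     task = asyncio.create_task(p.run())
    #     ...
    #     task.cancel()                    # shut the reader down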


class PubsubWorkerExceptionHandler(Protocol):
    def __call__(self, e: BaseException, pubsub: PubSub): ...


class AsyncPubsubWorkerExceptionHandler(Protocol):
    async def __call__(self, e: BaseException, pubsub: PubSub): ...


PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]


CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
CommandStackT = List[CommandT]


class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline (when running in transactional
    mode, which is the default behavior) are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

    def __init__(
        self,
        connection_pool: ConnectionPool,
        response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
        transaction: bool,
        shard_hint: Optional[str],
    ):
        self.connection_pool = connection_pool
        self.connection = None
        self.response_callbacks = response_callbacks
        self.is_transaction = transaction
        self.shard_hint = shard_hint
        self.watching = False
        self.command_stack: CommandStackT = []
        self.scripts: Set[Script] = set()
        self.explicit_transaction = False

    async def __aenter__(self: _RedisT) -> _RedisT:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.reset()

    def __await__(self):
        return self._async_self().__await__()

    _DEL_MESSAGE = "Unclosed Pipeline client"

    def __len__(self):
        return len(self.command_stack)

    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        return True

    async def _async_self(self):
        return self

    async def reset(self):
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None

    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    def execute_command(
        self, *args, **kwargs
    ) -> Union["Pipeline", Awaitable["Pipeline"]]:
        if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
            return self.immediate_execute_command(*args, **kwargs)
        return self.pipeline_execute_command(*args, **kwargs)

    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
    ):
        """
        Close the connection, reset the watching state, and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors for retry if we're already WATCHing a variable.
        Used when issuing WATCH or subsequent commands that retrieve their
        values, but before MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        return await conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise_on_watching(conn, error),
        )

    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self
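
    # Example (editor's illustrative sketch; key names are assumptions):
    # staging several commands and sending them in one round trip.
    #
    #     async with r.pipeline(transaction=False) as pipe:
    #         pipe.set("foo", "bar").incr("baz").decr("bang")
    #         results = await pipe.execute()  # one list, in command order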

    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove the "keys" entry; it is only needed for client-side caching.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data

    async def _execute_pipeline(
        self, connection: Connection, commands: CommandStackT, raise_on_error: bool
    ):
        # build up all commands into a single request to increase network perf
        all_cmds = connection.pack_commands([args for args, _ in commands])
        await connection.send_packed_command(all_cmds)

        response = []
        for args, options in commands:
            try:
                response.append(
                    await self.parse_response(connection, args[0], **options)
                )
            except ResponseError as e:
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)
        return response

    def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
        for i, r in enumerate(response):
            if isinstance(r, ResponseError):
                self.annotate_exception(r, i + 1, commands[i][0])
                raise r

    def annotate_exception(
        self, exception: Exception, number: int, command: Iterable[object]
    ) -> None:
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) "
            f"of pipeline caused error: {exception.args}"
        )
        exception.args = (msg,) + exception.args[1:]

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        result = await super().parse_response(connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

    async def load_scripts(self):
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = await immediate("SCRIPT EXISTS", *shas)
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await immediate("SCRIPT LOAD", s.script)

    async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
        """
        Close the connection and raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def execute(self, raise_on_error: bool = True) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        stack = self.command_stack
        if not stack and not self.watching:
            return []
        if self.scripts:
            await self.load_scripts()
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        try:
            return await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_on_watching(conn, error),
            )
        finally:
            await self.reset()

    async def discard(self):
        """Flushes all previously queued commands
        See: https://redis.io/commands/DISCARD
        """
        await self.execute_command("DISCARD")

    async def watch(self, *names: KeyT):
        """Watches the values at keys ``names``"""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        return self.watching and await self.execute_command("UNWATCH") or True
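
    # Example (editor's illustrative sketch; the key name is an assumption,
    # and decode_responses=True is assumed so GET returns a str): a manual
    # WATCH/MULTI/EXEC loop, the hand-rolled equivalent of Redis.transaction().
    #
    #     async with r.pipeline() as pipe:
    #         while True:
    #             try:
    #                 await pipe.watch("balance")
    #                 balance = int(await pipe.get("balance") or 0)
    #                 pipe.multi()
    #                 pipe.set("balance", balance + 10)
    #                 await pipe.execute()
    #                 break
    #             except WatchError:
    #                 continue  # another client changed "balance"; retry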