Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import copy
3import inspect
4import re
5import warnings
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 AsyncIterator,
10 Awaitable,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 MutableMapping,
17 Optional,
18 Protocol,
19 Set,
20 Tuple,
21 Type,
22 TypedDict,
23 TypeVar,
24 Union,
25 cast,
26)
28from redis._parsers.helpers import (
29 _RedisCallbacks,
30 _RedisCallbacksRESP2,
31 _RedisCallbacksRESP3,
32 bool_ok,
33)
34from redis.asyncio.connection import (
35 Connection,
36 ConnectionPool,
37 SSLConnection,
38 UnixDomainSocketConnection,
39)
40from redis.asyncio.lock import Lock
41from redis.asyncio.retry import Retry
42from redis.backoff import ExponentialWithJitterBackoff
43from redis.client import (
44 EMPTY_RESPONSE,
45 NEVER_DECODE,
46 AbstractRedis,
47 CaseInsensitiveDict,
48)
49from redis.commands import (
50 AsyncCoreCommands,
51 AsyncRedisModuleCommands,
52 AsyncSentinelCommands,
53 list_or_args,
54)
55from redis.credentials import CredentialProvider
56from redis.event import (
57 AfterPooledConnectionsInstantiationEvent,
58 AfterPubSubConnectionInstantiationEvent,
59 AfterSingleConnectionInstantiationEvent,
60 ClientType,
61 EventDispatcher,
62)
63from redis.exceptions import (
64 ConnectionError,
65 ExecAbortError,
66 PubSubError,
67 RedisError,
68 ResponseError,
69 WatchError,
70)
71from redis.typing import ChannelT, EncodableT, KeyT
72from redis.utils import (
73 SSL_AVAILABLE,
74 _set_info_logger,
75 deprecated_args,
76 deprecated_function,
77 get_lib_version,
78 safe_str,
79 str_if_bytes,
80 truncate_text,
81)
if TYPE_CHECKING and SSL_AVAILABLE:
    # Static type checkers see the real ssl enum classes for annotations.
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    # At runtime TYPE_CHECKING is False, so this branch always runs: the names
    # only have to exist so the annotations in Redis.__init__ can be evaluated.
    TLSVersion = VerifyMode = VerifyFlags = None
if TYPE_CHECKING:
    # Imported for annotations only; never imported at runtime.
    from redis.commands.core import Script

# A callable that takes a message dict and returns an awaitable (an async
# pub/sub message handler).
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]

# Generic helpers for command signatures. NOTE(review): _KeyT, _ArgT and
# _NormalizeKeysT are not used in the visible part of this file; presumably
# they are used further down.
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Lets Redis.initialize()/__aenter__ return the concrete subclass type.
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback."""

    def __call__(self, response: Any, **kwargs):
        """Post-process the raw ``response``; ``kwargs`` are command options."""
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of a coroutine response callback."""

    async def __call__(self, response: Any, **kwargs):
        """Asynchronously post-process the raw ``response``."""
# A response callback may be a plain callable or a coroutine function;
# Redis.parse_response awaits the result when it is awaitable.
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Command name -> callable that post-processes the raw reply in
    # parse_response(). Built in __init__ as a CaseInsensitiveDict.
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        The pool created here is owned by the returned client: unless the
        deprecated ``auto_close_connection_pool`` is passed explicitly, it is
        set to True and aclose() disconnects the pool.
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        # Passing a pool makes __init__ set auto_close_connection_pool=False;
        # the value is overridden below because this client created the pool.
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        # __init__ sets this to False for caller-supplied pools; flip it so
        # aclose() disconnects the pool this client now owns.
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object (in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.
        It is accepted for compatibility only and is not read by this method.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used. In that case every connection-related
        argument (host, port, ssl_*, retry, ...) is ignored, and the pool is
        never closed by this client.

        When ``ssl`` is True it only applies to TCP connections; it is ignored
        when ``unix_socket_path`` is given.
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # Deep copy so the shared default Retry instance is never
                # handed to (and possibly mutated by) a pool.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        # Dedicated connection; only set by initialize() when
        # single_connection_client is True, cleared again by aclose().
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # Layer the protocol-specific callbacks on top of the shared ones;
        # the protocol may be stored as int or str in connection_kwargs.
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        # Allows ``client = await Redis(...)`` as shorthand for initialize().
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """
        Acquire the dedicated connection for single-connection clients.

        A no-op for pooled clients. Returns ``self`` so it can be chained.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                # Re-checked under the lock so concurrent callers create at
                # most one dedicated connection.
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        """Return the Retry object stored in the pool's connection kwargs, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        """Store ``retry`` in the connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func) -> None:
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        # Binds on this instance only, not on the class.
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        # The pipeline receives this client's pool and its response_callbacks
        # mapping object (NOTE(review): whether Pipeline copies it is not
        # visible here).
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        ``func`` may be a plain function or return an awaitable. On
        WatchError the whole attempt (WATCH, ``func``, EXECUTE) is retried
        indefinitely, sleeping ``watch_delay`` seconds between attempts when
        it is positive. Returns the value of ``func`` if
        ``value_from_callable`` is True, otherwise the result of execute().
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        # The PubSub shares this client's pool and event dispatcher; extra
        # keyword arguments are forwarded to PubSub.__init__.
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        """Return a Monitor bound to this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """
        Return a new single-connection client that shares this client's pool.

        Because a pool is passed in, the new client will not close it.
        """
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        # Shielded so a cancellation of the exiting task does not leave the
        # counter out of sync or the close half-done.
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warn and _grl as argument defaults since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Only a client still holding its dedicated connection (set by
        # initialize() in single-connection mode, cleared by aclose()) needs
        # cleanup. hasattr() covers an __init__ that failed before the
        # attribute was assigned.
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                # get_running_loop() raises RuntimeError when no loop is
                # running; reporting is then skipped.
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass
            # Synchronous close, since __del__ cannot await.
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            # Return the dedicated connection to the pool first.
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection) -> None:
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (on connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        # Creates the dedicated connection on first use in single-connection mode.
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        # Use the dedicated connection if any; otherwise borrow one from the
        # pool for the duration of this command.
        conn = self.connection or await pool.get_connection()

        # The dedicated connection is shared by every coroutine using this
        # client, so commands on it are serialized.
        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            # On a retryable failure _close_connection disconnects conn
            # before the next attempt.
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            # Only borrowed connections go back to the pool.
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """
        Parses a response from the Redis server

        Recognized options (removed before the callback is invoked):
        ``NEVER_DECODE`` reads the reply without decoding, and
        ``EMPTY_RESPONSE`` is returned instead of raising on ResponseError.
        Remaining options are passed to the command's response callback,
        whose result is awaited if it is awaitable.
        """
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response
# Alias of Redis; presumably kept so code importing ``StrictRedis`` keeps working.
StrictRedis = Redis
class MonitorCommandInfo(TypedDict, total=True):
    """One parsed line of MONITOR output, as returned by Monitor.next_command()."""

    # Server timestamp at the start of the MONITOR line.
    time: float
    # Database index the command ran against.
    db: int
    # "lua", "unix", or the host part of "<host>:<port>" for TCP clients.
    client_address: str
    # Port for TCP clients, socket path for unix clients, "" for lua.
    client_port: str
    # One of "lua", "unix" or "tcp".
    client_type: str
    # The command and its arguments joined by spaces, with quotes unescaped.
    command: str
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use it as an async context manager: entering borrows a connection from
    ``connection_pool`` and issues MONITOR; exiting disconnects that
    connection and returns it to the pool.
    """

    # Applied to the part of a MONITOR line after the timestamp, e.g.
    # '[0 127.0.0.1:6380] "GET" "k"' -> ("0", "127.0.0.1:6380", '"GET" "k"').
    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    # One double-quoted argument; the lookbehind skips backslash-escaped quotes.
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Borrow a connection from the pool unless one is already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def _release_connection(self) -> None:
        """
        Disconnect the held connection (if any) and return it to the pool.

        ``self.connection`` is cleared first so this Monitor never keeps
        using a connection that the pool may hand to another client.
        """
        conn = self.connection
        if conn is None:
            return
        self.connection = None
        try:
            # Disconnect before releasing -- presumably because a connection
            # left in MONITOR mode cannot serve ordinary commands.
            await conn.disconnect()
        finally:
            # Release even if disconnect() raised so the pool slot is not lost.
            await self.connection_pool.release(conn)

    async def __aenter__(self):
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so give the
            # connection back here (also on cancellation) or it leaks.
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        await self._release_connection()

    async def next_command(self) -> MonitorCommandInfo:
        """
        Parse the response from a monitor command

        Reads one MONITOR line (connecting first if needed) and returns it
        as a ``MonitorCommandInfo`` dict.

        Raises:
            RedisError: if the line after the timestamp does not have the
                ``[<db> <client>] <args>`` shape.
        """
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        if m is None:
            # Previously surfaced as an opaque AttributeError on None.groups().
            raise RedisError(f"Unexpected MONITOR output: {response!r}")
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            # Skip the "unix:" prefix; the remainder is the socket path.
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
851class PubSub:
852 """
853 PubSub provides publish, subscribe and listen support to Redis channels.
855 After subscribing to one or more channels, the listen() method will block
856 until a message arrives on one of the subscribed channels. That message
857 will be returned and it's safe to start listening again.
858 """
860 PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
861 UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
862 HEALTH_CHECK_MESSAGE = "redis-py-health-check"
864 def __init__(
865 self,
866 connection_pool: ConnectionPool,
867 shard_hint: Optional[str] = None,
868 ignore_subscribe_messages: bool = False,
869 encoder=None,
870 push_handler_func: Optional[Callable] = None,
871 event_dispatcher: Optional["EventDispatcher"] = None,
872 ):
873 if event_dispatcher is None:
874 self._event_dispatcher = EventDispatcher()
875 else:
876 self._event_dispatcher = event_dispatcher
877 self.connection_pool = connection_pool
878 self.shard_hint = shard_hint
879 self.ignore_subscribe_messages = ignore_subscribe_messages
880 self.connection = None
881 # we need to know the encoding options for this connection in order
882 # to lookup channel and pattern names for callback handlers.
883 self.encoder = encoder
884 self.push_handler_func = push_handler_func
885 if self.encoder is None:
886 self.encoder = self.connection_pool.get_encoder()
887 if self.encoder.decode_responses:
888 self.health_check_response = [
889 ["pong", self.HEALTH_CHECK_MESSAGE],
890 self.HEALTH_CHECK_MESSAGE,
891 ]
892 else:
893 self.health_check_response = [
894 [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
895 self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
896 ]
897 if self.push_handler_func is None:
898 _set_info_logger()
899 self.channels = {}
900 self.pending_unsubscribe_channels = set()
901 self.patterns = {}
902 self.pending_unsubscribe_patterns = set()
903 self._lock = asyncio.Lock()
905 async def __aenter__(self):
906 return self
908 async def __aexit__(self, exc_type, exc_value, traceback):
909 await self.aclose()
911 def __del__(self):
912 if self.connection:
913 self.connection.deregister_connect_callback(self.on_connect)
915 async def aclose(self):
916 # In case a connection property does not yet exist
917 # (due to a crash earlier in the Redis() constructor), return
918 # immediately as there is nothing to clean-up.
919 if not hasattr(self, "connection"):
920 return
921 async with self._lock:
922 if self.connection:
923 await self.connection.disconnect()
924 self.connection.deregister_connect_callback(self.on_connect)
925 await self.connection_pool.release(self.connection)
926 self.connection = None
927 self.channels = {}
928 self.pending_unsubscribe_channels = set()
929 self.patterns = {}
930 self.pending_unsubscribe_patterns = set()
932 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
933 async def close(self) -> None:
934 """Alias for aclose(), for backwards compatibility"""
935 await self.aclose()
937 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
938 async def reset(self) -> None:
939 """Alias for aclose(), for backwards compatibility"""
940 await self.aclose()
async def on_connect(self, connection: Connection):
    """Re-subscribe to any channels and patterns previously subscribed to.

    ``connect()`` registers this as the connection's connect callback.
    Both pending-unsubscribe sets are cleared first, then every remembered
    channel/pattern is re-subscribed together with its handler.
    """

    def _as_kwargs(registry):
        # Names must be str to be passed as keyword arguments (bytes cannot
        # be keyword names), so decode them before calling [p]subscribe.
        return {
            self.encoder.decode(name, force=True): handler
            for name, handler in registry.items()
        }

    self.pending_unsubscribe_channels.clear()
    self.pending_unsubscribe_patterns.clear()
    if self.channels:
        await self.subscribe(**_as_kwargs(self.channels))
    if self.patterns:
        await self.psubscribe(**_as_kwargs(self.patterns))
@property
def subscribed(self):
    """True when at least one channel or pattern subscription is tracked."""
    if self.channels or self.patterns:
        return True
    return False
async def execute_command(self, *args: EncodableT):
    """Send a publish/subscribe command without reading its reply.

    Connects first if needed. Replies are consumed later through
    ``parse_response`` / ``get_message`` / ``listen``.
    """
    # Deliberately no response parsing here: once subscribed, the next item
    # on the connection may be a published message rather than this
    # command's reply.
    await self.connect()
    conn = self.connection
    # The connection-level health check is only requested while nothing is
    # subscribed (presumably so it cannot interfere with pushed messages).
    wants_health_check = not self.subscribed
    await self._execute(
        conn, conn.send_command, *args, check_health=wants_health_check
    )
async def connect(self):
    """
    Ensure that the PubSub is connected

    On first use a connection is taken from the pool and ``on_connect`` is
    registered as its connect callback. Otherwise the existing connection's
    ``connect()`` is awaited. In both cases the optional push handler is
    installed on the parser and an
    ``AfterPubSubConnectionInstantiationEvent`` is dispatched.
    """
    if self.connection is not None:
        await self.connection.connect()
    else:
        self.connection = await self.connection_pool.get_connection()
        # register a callback that re-subscribes to any channels we
        # were listening to when we were disconnected
        self.connection.register_connect_callback(self.on_connect)

    conn = self.connection
    if self.push_handler_func is not None:
        conn._parser.set_pubsub_push_handler(self.push_handler_func)

    event = AfterPubSubConnectionInstantiationEvent(
        conn, self.connection_pool, ClientType.ASYNC, self._lock
    )
    self._event_dispatcher.dispatch(event)
async def _reconnect(self, conn):
    """Drop and re-open ``conn``; used as the retry failure callback."""
    # Order matters: tear down first, then connect again.
    await conn.disconnect()
    await conn.connect()
async def _execute(self, conn, command, *args, **kwargs):
    """Run ``command(*args, **kwargs)`` under ``conn.retry``.

    Between attempts the connection is re-established with ``_reconnect``.
    Reconnecting fires the ``on_connect`` callback, which re-subscribes any
    channels and patterns we were listening to. If the Redis server is down,
    the error (e.g. ConnectionError) propagates to the caller.
    """

    def attempt():
        return command(*args, **kwargs)

    def on_failure(_error):
        return self._reconnect(conn)

    return await conn.retry.call_with_retry(attempt, on_failure)
async def parse_response(self, block: bool = True, timeout: float = 0):
    """Parse the response from a publish/subscribe command

    With ``block=True`` no read timeout is passed; otherwise ``timeout`` is
    used as the read timeout. Replies matching the health-check PONG are
    swallowed (None is returned) when health checks are enabled.

    Raises RuntimeError if no connection has been set up yet.
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )

    await self.check_health()

    if not conn.is_connected:
        await conn.connect()

    if block:
        read_timeout = None
    else:
        read_timeout = timeout

    response = await self._execute(
        conn,
        conn.read_response,
        timeout=read_timeout,
        disconnect_on_error=False,
        push_request=True,
    )

    # Hide our own health-check reply; callers never asked for it.
    is_health_reply = (
        bool(conn.health_check_interval)
        and response in self.health_check_response
    )
    return None if is_health_reply else response
async def check_health(self):
    """Send a health-check PING once the connection's next check is due.

    Nothing is sent when ``health_check_interval`` is falsy or the event
    loop clock has not yet passed ``next_health_check``. Raises RuntimeError
    if no connection has been set up yet.
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )

    if not conn.health_check_interval:
        return
    now = asyncio.get_running_loop().time()
    if now > conn.next_health_check:
        await conn.send_command(
            "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
        )
def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
    """
    Return a copy of ``data`` whose keys went through encode-then-decode.

    This makes channel/pattern names bytes or str depending on whether
    responses are decoded. Stored keys then match incoming messages without
    converting every message.
    """
    encoder = self.encoder
    encode, decode = encoder.encode, encoder.decode
    normalized = {}
    for name, value in data.items():
        normalized[decode(encode(name))] = value
    return normalized  # type: ignore[return-value]
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
    """
    Subscribe to channel patterns. Patterns supplied as keyword arguments
    expect a pattern name as the key and a callable as the value. A
    pattern's callable will be invoked automatically when a message is
    received on that pattern rather than producing a message via
    ``listen()``.

    Returns the result of ``execute_command`` for PSUBSCRIBE.
    """
    if args:
        parsed_args = list_or_args((args[0],), args[1:])
    else:
        parsed_args = args
    # Positional patterns get no handler (None); a keyword handler wins
    # when the same name is given both ways.
    new_patterns: Dict[ChannelT, PubSubHandler] = {
        **dict.fromkeys(parsed_args),
        **kwargs,  # type: ignore[dict-item]
    }
    ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns)
    # Only record the patterns once the command has been sent. Otherwise a
    # reconnect would subscribe to them a second time, once for the command
    # and again from the reconnection callback.
    normalized = self._normalize_keys(new_patterns)
    self.patterns.update(normalized)
    self.pending_unsubscribe_patterns.difference_update(normalized)
    return ret_val
def punsubscribe(self, *args: ChannelT) -> Awaitable:
    """
    Unsubscribe from the supplied patterns. If empty, unsubscribe from
    all patterns.

    The affected patterns are marked as pending. ``handle_message`` drops
    them once the server confirms the unsubscribe.
    """
    patterns: Iterable[ChannelT]
    if not args:
        parsed_args = []
        patterns = self.patterns
    else:
        parsed_args = list_or_args((args[0],), args[1:])
        patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
    self.pending_unsubscribe_patterns.update(patterns)
    return self.execute_command("PUNSUBSCRIBE", *parsed_args)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
    """
    Subscribe to channels. Channels supplied as keyword arguments expect
    a channel name as the key and a callable as the value. A channel's
    callable will be invoked automatically when a message is received on
    that channel rather than producing a message via ``listen()`` or
    ``get_message()``.

    Returns the result of ``execute_command`` for SUBSCRIBE.
    """
    if args:
        parsed_args = list_or_args((args[0],), args[1:])
    else:
        parsed_args = ()
    # Positional channels get no handler (None); a keyword handler wins
    # when the same name is given both ways.
    new_channels = {
        **dict.fromkeys(parsed_args),
        **kwargs,  # type: ignore[dict-item]
    }
    ret_val = await self.execute_command("SUBSCRIBE", *new_channels)
    # Only record the channels once the command has been sent. Otherwise a
    # reconnect would subscribe to them a second time, once for the command
    # and again from the reconnection callback.
    normalized = self._normalize_keys(new_channels)
    self.channels.update(normalized)
    self.pending_unsubscribe_channels.difference_update(normalized)
    return ret_val
def unsubscribe(self, *args) -> Awaitable:
    """
    Unsubscribe from the supplied channels. If empty, unsubscribe from
    all channels

    The affected channels are marked as pending. ``handle_message`` drops
    them once the server confirms the unsubscribe.
    """
    # NOTE(review): unlike punsubscribe, args[0] is handed to list_or_args
    # as-is rather than wrapped in a tuple. This presumably also accepts a
    # single iterable of names; confirm before unifying the two.
    if not args:
        parsed_args = []
        channels = self.channels
    else:
        parsed_args = list_or_args(args[0], args[1:])
        channels = self._normalize_keys(dict.fromkeys(parsed_args))
    self.pending_unsubscribe_channels.update(channels)
    return self.execute_command("UNSUBSCRIBE", *parsed_args)
async def listen(self) -> AsyncIterator:
    """Yield messages for as long as any channel or pattern is subscribed.

    Responses for which ``handle_message`` returns None (e.g. ones consumed
    by a registered handler) are skipped.
    """
    while self.subscribed:
        raw = await self.parse_response(block=True)
        message = await self.handle_message(raw)
        if message is None:
            continue
        yield message
async def get_message(
    self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
    """
    Get the next message if one is available, otherwise None.

    If timeout is specified, the system will wait for `timeout` seconds
    before returning. Timeout should be specified as a floating point
    number or None to wait indefinitely.
    """
    wait_forever = timeout is None
    response = await self.parse_response(block=wait_forever, timeout=timeout)
    if not response:
        return None
    return await self.handle_message(response, ignore_subscribe_messages)
def ping(self, message=None) -> Awaitable[bool]:
    """
    Send a PING (optionally carrying ``message``) over the pub/sub connection.

    Returns the awaitable from ``execute_command``, which sends the command
    without reading a reply. The server's answer is read later through
    ``get_message()`` / ``listen()``, where it appears as a ``"pong"``
    message.
    """
    command = ("PING",) if message is None else ("PING", message)
    return self.execute_command(*command)
async def handle_message(self, response, ignore_subscribe_messages=False):
    """
    Parses a pub/sub message. If the channel or pattern was subscribed to
    with a message handler, the handler is invoked instead of a parsed
    message being returned.

    Returns a dict with keys ``type``, ``pattern``, ``channel`` and ``data``,
    or None if:

    * ``response`` is None,
    * a registered handler consumed the message, or
    * it is a (un)subscribe confirmation and such messages are ignored.

    Handlers may be plain callables, coroutine functions, or any callable
    returning a coroutine (e.g. an object with ``async def __call__``). A
    returned coroutine is awaited.
    """
    if response is None:
        return None
    if isinstance(response, bytes):
        # A bare bytes reply is treated as a PING reply and normalized to
        # the [b"pong", data] shape.
        response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
    message_type = str_if_bytes(response[0])
    if message_type == "pmessage":
        message = {
            "type": message_type,
            "pattern": response[1],
            "channel": response[2],
            "data": response[3],
        }
    elif message_type == "pong":
        message = {
            "type": message_type,
            "pattern": None,
            "channel": None,
            "data": response[1],
        }
    else:
        message = {
            "type": message_type,
            "pattern": None,
            "channel": response[1],
            "data": response[2],
        }

    # if this is an unsubscribe message, remove it from memory
    if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
        if message_type == "punsubscribe":
            pattern = response[1]
            if pattern in self.pending_unsubscribe_patterns:
                self.pending_unsubscribe_patterns.remove(pattern)
                self.patterns.pop(pattern, None)
        else:
            channel = response[1]
            if channel in self.pending_unsubscribe_channels:
                self.pending_unsubscribe_channels.remove(channel)
                self.channels.pop(channel, None)

    if message_type in self.PUBLISH_MESSAGE_TYPES:
        # if there's a message handler, invoke it
        if message_type == "pmessage":
            handler = self.patterns.get(message["pattern"], None)
        else:
            handler = self.channels.get(message["channel"], None)
        if handler:
            result = handler(message)
            # Checking only iscoroutinefunction(handler) misses async
            # callables that are not coroutine functions (async __call__,
            # a lambda returning a coroutine). Their coroutine would never
            # run. Other awaitables a sync handler returns (e.g. a Task it
            # spawned) are deliberately left alone, as before.
            if inspect.iscoroutinefunction(handler) or asyncio.iscoroutine(result):
                await result
            return None
    elif message_type != "pong":
        # this is a subscribe/unsubscribe message. ignore if we don't
        # want them
        if ignore_subscribe_messages or self.ignore_subscribe_messages:
            return None

    return message
async def run(
    self,
    *,
    exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
    poll_timeout: float = 1.0,
    pubsub=None,
) -> None:
    """Process pub/sub messages using registered callbacks.

    This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
    redis-py, but it is a coroutine. To launch it as a separate task, use
    ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

    To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task

    Args:
        exception_handler: called as ``exception_handler(error, self)`` for
            any error other than cancellation. An awaitable result is
            awaited. Without a handler the error propagates and ends the
            loop.
        poll_timeout: timeout passed to each ``get_message`` call.
        pubsub: if given, messages are polled from this object's
            ``get_message`` instead of from ``self``.

    Raises:
        PubSubError: if any subscribed channel or pattern has no handler.
    """
    # Messages returned by get_message are discarded below, so every
    # subscription needs a handler or its messages would be lost.
    for channel_name, channel_handler in self.channels.items():
        if channel_handler is None:
            raise PubSubError(f"Channel: '{channel_name}' has no handler registered")
    for pattern_name, pattern_handler in self.patterns.items():
        if pattern_handler is None:
            raise PubSubError(f"Pattern: '{pattern_name}' has no handler registered")

    await self.connect()
    poller = self if pubsub is None else pubsub
    while True:
        try:
            await poller.get_message(
                ignore_subscribe_messages=True, timeout=poll_timeout
            )
        except asyncio.CancelledError:
            # Cancellation is how this loop is stopped: never hand it to
            # the exception handler.
            raise
        except BaseException as exc:
            if exception_handler is None:
                raise
            outcome = exception_handler(exc, self)
            if inspect.isawaitable(outcome):
                await outcome
        # Ensure that other tasks on the event loop get a chance to run
        # if we didn't have to block for I/O anywhere.
        await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Synchronous error callback accepted by ``PubSub.run``."""

    def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle ``e`` raised while ``pubsub`` was processing messages."""
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Asynchronous error callback accepted by ``PubSub.run`` (awaited)."""

    async def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle ``e`` raised while ``pubsub`` was processing messages."""
# Type of ``PubSub.run``'s ``exception_handler``: either a sync or an async
# callback. run() awaits the result when it is awaitable.
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler,
    AsyncPubsubWorkerExceptionHandler,
]
# One buffered pipeline command: (command name followed by its arguments,
# per-command options).
CommandT = Tuple[
    Tuple[Union[str, bytes], ...],
    Mapping[str, Any],
]
# Commands buffered by a Pipeline, in the order they were queued.
CommandStackT = List[CommandT]
class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline (when running in transactional
    mode, which is the default behavior) are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    # Commands after which parse_response() marks the pipeline as no longer
    # watching any keys.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

    def __init__(
        self,
        connection_pool: ConnectionPool,
        response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
        transaction: bool,
        shard_hint: Optional[str],
    ):
        # Redis.__init__ is not called; the pipeline's state is set up
        # directly from the arguments.
        self.connection_pool = connection_pool
        self.connection = None
        self.response_callbacks = response_callbacks
        self.is_transaction = transaction
        self.shard_hint = shard_hint
        self.watching = False
        self.command_stack: CommandStackT = []
        self.scripts: Set[Script] = set()
        self.explicit_transaction = False

    async def __aenter__(self: _RedisT) -> _RedisT:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Drops any commands still buffered and releases the connection.
        await self.reset()

    def __await__(self):
        # Queued commands return the pipeline itself, so ``await pipe.cmd()``
        # must work: awaiting a Pipeline simply yields the pipeline.
        return self._async_self().__await__()

    _DEL_MESSAGE = "Unclosed Pipeline client"

    def __len__(self):
        """Number of commands currently buffered."""
        return len(self.command_stack)

    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        return True

    async def _async_self(self):
        return self

    async def reset(self):
        """Discard buffered commands and scripts and release the connection.

        If keys are being WATCHed, an UNWATCH is sent first. If that fails
        with ConnectionError, the connection is disconnected instead.
        """
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None

    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.

        Raises RedisError if MULTI was already issued, or if commands have
        already been buffered.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    def execute_command(
        self, *args, **kwargs
    ) -> Union["Pipeline", Awaitable["Pipeline"]]:
        """Run WATCH (and commands while watching) immediately, else buffer.

        Outside an explicit MULTI, WATCH and every command issued while
        watching are sent right away. Everything else is queued for
        ``execute()``.
        """
        if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
            return self.immediate_execute_command(*args, **kwargs)
        return self.pipeline_execute_command(*args, **kwargs)

    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
    ):
        """
        Close the connection, reset watching state and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        await conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors for retry if we're already WATCHing a variable.
        Used when issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        return await conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise_on_watching(conn, error),
        )

    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self

    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        """Send ``commands`` wrapped in MULTI/EXEC and return their results.

        Commands flagged with EMPTY_RESPONSE are not sent; their placeholder
        value is spliced into the results instead. Raises WatchError if EXEC
        reports that a watched key changed.
        """
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove keys entry, it needs only for cache.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data

    async def _execute_pipeline(
        self, connection: Connection, commands: CommandStackT, raise_on_error: bool
    ):
        """Send ``commands`` without MULTI/EXEC and collect their replies.

        A ResponseError for an individual command is stored in its result
        slot instead of stopping the pipeline.
        """
        # build up all commands into a single request to increase network perf
        all_cmds = connection.pack_commands([args for args, _ in commands])
        await connection.send_packed_command(all_cmds)

        response = []
        for args, options in commands:
            try:
                response.append(
                    await self.parse_response(connection, args[0], **options)
                )
            except ResponseError as e:
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)
        return response

    def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
        """Annotate and raise the first ResponseError found in ``response``."""
        for i, r in enumerate(response):
            if isinstance(r, ResponseError):
                self.annotate_exception(r, i + 1, commands[i][0])
                raise r

    def annotate_exception(
        self, exception: Exception, number: int, command: Iterable[object]
    ) -> None:
        """Rewrite ``exception``'s message to name the failing command.

        ``number`` is the command's 1-based position in the pipeline. The
        original first argument of the exception is kept in the new message.
        """
        cmd = " ".join(map(safe_str, command))
        # args[0] is replaced below, so carry its text over into the message.
        detail = exception.args[0] if exception.args else ""
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) "
            f"of pipeline caused error: {detail}"
        )
        exception.args = (msg,) + exception.args[1:]

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parse a reply and keep the ``watching`` flag in sync with it."""
        result = await super().parse_response(connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

    async def load_scripts(self):
        """Make sure every script queued on this pipeline exists on the server.

        Sends SCRIPT EXISTS for all scripts, then SCRIPT LOAD for any that
        are missing, storing the returned sha on the script object.
        """
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = await immediate("SCRIPT EXISTS", *shas)
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await immediate("SCRIPT LOAD", s.script)

    async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
        """
        Close the connection, raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def execute(self, raise_on_error: bool = True) -> List[Any]:
        """Execute all the commands in the current pipeline

        Uses MULTI/EXEC when the pipeline is transactional or ``multi()`` was
        called. The pipeline is always reset afterwards, even on error.
        """
        stack = self.command_stack
        if not stack and not self.watching:
            return []
        if self.scripts:
            await self.load_scripts()
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        try:
            return await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_on_watching(conn, error),
            )
        finally:
            await self.reset()

    async def discard(self):
        """Flushes all previously queued commands
        See: https://redis.io/commands/DISCARD
        """
        await self.execute_command("DISCARD")

    async def watch(self, *names: KeyT):
        """Watches the values at keys ``names``"""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Unwatches all previously specified keys

        UNWATCH is only sent while keys are being watched. Returns the
        command's result when it is truthy, otherwise True.
        """
        if not self.watching:
            return True
        return await self.execute_command("UNWATCH") or True