Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/asyncio/client.py: 23%
660 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 06:16 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 06:16 +0000
1import asyncio
2import copy
3import inspect
4import re
5import ssl
6import warnings
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 AsyncIterator,
11 Awaitable,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Mapping,
17 MutableMapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26 cast,
27)
29from redis._cache import (
30 DEFAULT_BLACKLIST,
31 DEFAULT_EVICTION_POLICY,
32 DEFAULT_WHITELIST,
33 AbstractCache,
34)
35from redis._parsers.helpers import (
36 _RedisCallbacks,
37 _RedisCallbacksRESP2,
38 _RedisCallbacksRESP3,
39 bool_ok,
40)
41from redis.asyncio.connection import (
42 Connection,
43 ConnectionPool,
44 SSLConnection,
45 UnixDomainSocketConnection,
46)
47from redis.asyncio.lock import Lock
48from redis.asyncio.retry import Retry
49from redis.client import (
50 EMPTY_RESPONSE,
51 NEVER_DECODE,
52 AbstractRedis,
53 CaseInsensitiveDict,
54)
55from redis.commands import (
56 AsyncCoreCommands,
57 AsyncRedisModuleCommands,
58 AsyncSentinelCommands,
59 list_or_args,
60)
61from redis.credentials import CredentialProvider
62from redis.exceptions import (
63 ConnectionError,
64 ExecAbortError,
65 PubSubError,
66 RedisError,
67 ResponseError,
68 TimeoutError,
69 WatchError,
70)
71from redis.typing import ChannelT, EncodableT, KeyT
72from redis.utils import (
73 HIREDIS_AVAILABLE,
74 _set_info_logger,
75 deprecated_function,
76 get_lib_version,
77 safe_str,
78 str_if_bytes,
79)
# Signature of a pub/sub message handler: an async callable that receives a
# message dict and returns an awaitable resolving to None.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
# Generic key / argument type variables (not referenced in the part of this
# module shown here).
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Lets ``initialize()`` / ``__aenter__`` return the concrete (sub)class type.
_RedisT = TypeVar("_RedisT", bound="Redis")
# Return type of ``PubSub._normalize_keys``: same mapping type as its input.
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
# Imported for static type checking only; never imported at runtime.
if TYPE_CHECKING:
    from redis.commands.core import Script
class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback.

    Any callable taking the raw ``response`` plus arbitrary keyword options
    satisfies this protocol.
    """

    def __call__(self, response: Any, **kwargs):
        """Transform the raw ``response`` into the value handed to the caller."""
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of an asynchronous response callback.

    Like ``ResponseCallbackProtocol`` but ``__call__`` is a coroutine function,
    so its result must be awaited.
    """

    async def __call__(self, response: Any, **kwargs):
        """Asynchronously transform the raw ``response``."""
# A response callback may be either synchronous or asynchronous; callers
# await the result only when it is awaitable (see Redis.parse_response).
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Command name -> callback that post-processes the raw reply in
    # ``parse_response``. Populated in ``__init__`` as a CaseInsensitiveDict.
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        ``auto_close_connection_pool`` is deprecated (a DeprecationWarning is
        emitted when it is given); when omitted, the pool created here is
        closed together with the client.
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        # The constructor disables auto-close for passed-in pools; this
        # factory created the pool itself, so it overrides that afterwards.
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: str = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = False,
        ssl_min_version: Optional[ssl.TLSVersion] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        cache_enabled: bool = False,
        client_cache: Optional[AbstractCache] = None,
        cache_max_size: int = 100,
        cache_ttl: int = 0,
        cache_policy: str = DEFAULT_EVICTION_POLICY,
        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
        cache_whitelist: List[str] = DEFAULT_WHITELIST,
    ):
        """
        Initialize a new Redis client.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        When ``connection_pool`` is given, all connection-related arguments
        are ignored and the pool is never closed by this client. Otherwise a
        private ``ConnectionPool`` is built from the arguments. The caller's
        ``retry_on_error`` list is copied rather than modified, and ``retry``
        is deep-copied.
        """
        kwargs: Dict[str, Any]
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance.
            # Copy the caller's list: appending TimeoutError below must not
            # mutate an object the caller may reuse for other clients.
            retry_on_error = list(retry_on_error) if retry_on_error else []
            if retry_on_timeout is True:
                retry_on_error.append(TimeoutError)
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_timeout": retry_on_timeout,
                "retry_on_error": retry_on_error,
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
                "cache_enabled": cache_enabled,
                "client_cache": client_cache,
                "cache_max_size": cache_max_size,
                "cache_ttl": cache_ttl,
                "cache_policy": cache_policy,
                "cache_blacklist": cache_blacklist,
                "cache_whitelist": cache_whitelist,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # Layer the protocol-specific callbacks over the common ones.
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        """Allow ``client = await Redis(...)`` as a shorthand for ``initialize()``."""
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """Acquire the dedicated connection (single-connection mode only).

        Safe to call repeatedly: the connection is fetched from the pool only
        once, under ``_single_conn_lock``. Returns ``self``.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection("_")
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional["Retry"]:
        """Return the ``retry`` entry of the pool's connection kwargs, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: "Retry") -> None:
        """Store ``retry`` in the pool's connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load function functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        ``func`` may be sync or async. On ``WatchError`` the whole attempt is
        retried (after sleeping ``watch_delay`` seconds when it is positive)
        until it succeeds. Returns ``func``'s result when
        ``value_from_callable`` is true, otherwise the result of ``execute()``.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(self.connection_pool, **kwargs)

    def monitor(self) -> "Monitor":
        """Return a ``Monitor`` that uses this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """Return a new single-connection client sharing this client's pool."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        return await self.initialize()

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warn and _grl as argument defaults since the module globals
    # may already be gone by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        """Warn about, and forcibly close, a dedicated connection still held."""
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # get_running_loop() raises RuntimeError when no loop is running.
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        :param close_connection_pool: decides whether to close the connection pool used
        by this Redis client, overriding Redis.auto_close_connection_pool. By default,
        let Redis.auto_close_connection_pool decide whether to close the connection
        pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _disconnect_raise(self, conn: Connection, error: Exception):
        """
        Close the connection and raise an exception
        if retry_on_error is not set or the error
        is not one of the specified error types
        """
        await conn.disconnect()
        if conn.retry_on_error is None or not isinstance(
            error, tuple(conn.retry_on_error)
        ):
            raise error

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response.

        ``args[0]`` is the command name. ``keys`` (popped from ``options``) is
        only used for client-side caching; the remaining options are forwarded
        to ``get_connection`` and ``parse_response``. A connection taken from
        the pool is always released again, even when an error occurs.
        """
        await self.initialize()
        command_name = args[0]
        keys = options.pop("keys", None)  # keys are used only for client side caching
        pool = self.connection_pool
        conn = self.connection or await pool.get_connection(command_name, **options)
        try:
            # The cache lookup is inside ``try`` so that a pooled connection is
            # still released if the lookup itself raises.
            response_from_cache = await conn._get_from_local_cache(args)
            if response_from_cache is not None:
                return response_from_cache
            use_lock = self.single_connection_client
            if use_lock:
                # Acquire BEFORE entering the try/finally: if this task is
                # cancelled while waiting, the finally must not release a lock
                # that another task holds (asyncio.Lock tracks no owner).
                await self._single_conn_lock.acquire()
            try:
                response = await conn.retry.call_with_retry(
                    lambda: self._send_command_parse_response(
                        conn, command_name, *args, **options
                    ),
                    lambda error: self._disconnect_raise(conn, error),
                )
                conn._add_to_local_cache(args, response, keys)
                return response
            finally:
                if use_lock:
                    self._single_conn_lock.release()
        finally:
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server

        ``NEVER_DECODE`` in ``options`` reads the reply without decoding.
        ``EMPTY_RESPONSE`` in ``options`` is returned instead of raising when
        the server answers with a ``ResponseError``. The remaining options are
        passed to the registered response callback, whose result is awaited
        when it is awaitable.
        """
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        options.pop(EMPTY_RESPONSE, None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response

    def flush_cache(self):
        """Flush the client-side cache of the dedicated connection or the pool.

        Best-effort: an AttributeError is ignored (presumably raised when no
        client-side cache is configured -- NOTE(review): confirm).
        """
        try:
            if self.connection:
                self.connection.client_cache.flush()
            else:
                self.connection_pool.flush_cache()
        except AttributeError:
            pass

    def delete_command_from_cache(self, command):
        """Remove ``command`` from the client-side cache (best-effort, see flush_cache)."""
        try:
            if self.connection:
                self.connection.client_cache.delete_command(command)
            else:
                self.connection_pool.delete_command_from_cache(command)
        except AttributeError:
            pass

    def invalidate_key_from_cache(self, key):
        """Invalidate cached entries for ``key`` (best-effort, see flush_cache)."""
        try:
            if self.connection:
                self.connection.client_cache.invalidate_key(key)
            else:
                self.connection_pool.invalidate_key_from_cache(key)
        except AttributeError:
            pass
# ``StrictRedis`` is an alias: it refers to the very same class object as ``Redis``.
StrictRedis = Redis
# Shape of one parsed MONITOR line as returned by ``Monitor.next_command``:
# ``time`` is the server timestamp, ``db`` the database index, the three
# ``client_*`` fields describe the issuing client ("lua", "unix" or "tcp"),
# and ``command`` is the reconstructed command text.
MonitorCommandInfo = TypedDict(
    "MonitorCommandInfo",
    {
        "time": float,
        "db": int,
        "client_address": str,
        "client_port": str,
        "client_type": str,
        "command": str,
    },
)
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use it as an async context manager: entering switches a pooled
    connection into MONITOR mode; leaving disconnects it and returns it
    to the pool.
    """

    # "[<db> <client info>] <quoted command args>" (the part after the timestamp)
    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    # One double-quoted argument; a quote preceded by a backslash does not end it.
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Take a connection from the pool unless one is already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection("MONITOR")

    async def __aenter__(self):
        """Connect and put the connection into MONITOR mode.

        Raises:
            RedisError: if the server does not acknowledge MONITOR with OK.

        ``async with`` does not call ``__aexit__`` when ``__aenter__`` raises,
        so on any failure the connection is disconnected and handed back to
        the pool here before the exception propagates.
        """
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        """Disconnect the monitoring connection and return it to the pool."""
        await self._release_connection()

    async def _release_connection(self):
        """Disconnect the held connection, if any, and release it to the pool."""
        conn = self.connection
        if conn is None:
            return
        # Drop our reference first so a later connect() fetches a fresh
        # connection instead of reusing one that is already back in the pool.
        self.connection = None
        try:
            await conn.disconnect()
        finally:
            await self.connection_pool.release(conn)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command

        Reads one MONITOR line of the form
        ``<time> [<db> <client info>] "<arg>" "<arg>" ...`` and returns it as
        a ``MonitorCommandInfo`` dict.
        """
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            # everything after the "unix:" prefix
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
790class PubSub:
791 """
792 PubSub provides publish, subscribe and listen support to Redis channels.
794 After subscribing to one or more channels, the listen() method will block
795 until a message arrives on one of the subscribed channels. That message
796 will be returned and it's safe to start listening again.
797 """
799 PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
800 UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
801 HEALTH_CHECK_MESSAGE = "redis-py-health-check"
803 def __init__(
804 self,
805 connection_pool: ConnectionPool,
806 shard_hint: Optional[str] = None,
807 ignore_subscribe_messages: bool = False,
808 encoder=None,
809 push_handler_func: Optional[Callable] = None,
810 ):
811 self.connection_pool = connection_pool
812 self.shard_hint = shard_hint
813 self.ignore_subscribe_messages = ignore_subscribe_messages
814 self.connection = None
815 # we need to know the encoding options for this connection in order
816 # to lookup channel and pattern names for callback handlers.
817 self.encoder = encoder
818 self.push_handler_func = push_handler_func
819 if self.encoder is None:
820 self.encoder = self.connection_pool.get_encoder()
821 if self.encoder.decode_responses:
822 self.health_check_response = [
823 ["pong", self.HEALTH_CHECK_MESSAGE],
824 self.HEALTH_CHECK_MESSAGE,
825 ]
826 else:
827 self.health_check_response = [
828 [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
829 self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
830 ]
831 if self.push_handler_func is None:
832 _set_info_logger()
833 self.channels = {}
834 self.pending_unsubscribe_channels = set()
835 self.patterns = {}
836 self.pending_unsubscribe_patterns = set()
837 self._lock = asyncio.Lock()
839 async def __aenter__(self):
840 return self
842 async def __aexit__(self, exc_type, exc_value, traceback):
843 await self.aclose()
845 def __del__(self):
846 if self.connection:
847 self.connection.deregister_connect_callback(self.on_connect)
849 async def aclose(self):
850 # In case a connection property does not yet exist
851 # (due to a crash earlier in the Redis() constructor), return
852 # immediately as there is nothing to clean-up.
853 if not hasattr(self, "connection"):
854 return
855 async with self._lock:
856 if self.connection:
857 await self.connection.disconnect()
858 self.connection.deregister_connect_callback(self.on_connect)
859 await self.connection_pool.release(self.connection)
860 self.connection = None
861 self.channels = {}
862 self.pending_unsubscribe_channels = set()
863 self.patterns = {}
864 self.pending_unsubscribe_patterns = set()
866 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
867 async def close(self) -> None:
868 """Alias for aclose(), for backwards compatibility"""
869 await self.aclose()
871 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
872 async def reset(self) -> None:
873 """Alias for aclose(), for backwards compatibility"""
874 await self.aclose()
876 async def on_connect(self, connection: Connection):
877 """Re-subscribe to any channels and patterns previously subscribed to"""
878 # NOTE: for python3, we can't pass bytestrings as keyword arguments
879 # so we need to decode channel/pattern names back to unicode strings
880 # before passing them to [p]subscribe.
881 self.pending_unsubscribe_channels.clear()
882 self.pending_unsubscribe_patterns.clear()
883 if self.channels:
884 channels = {}
885 for k, v in self.channels.items():
886 channels[self.encoder.decode(k, force=True)] = v
887 await self.subscribe(**channels)
888 if self.patterns:
889 patterns = {}
890 for k, v in self.patterns.items():
891 patterns[self.encoder.decode(k, force=True)] = v
892 await self.psubscribe(**patterns)
894 @property
895 def subscribed(self):
896 """Indicates if there are subscriptions to any channels or patterns"""
897 return bool(self.channels or self.patterns)
899 async def execute_command(self, *args: EncodableT):
900 """Execute a publish/subscribe command"""
902 # NOTE: don't parse the response in this function -- it could pull a
903 # legitimate message off the stack if the connection is already
904 # subscribed to one or more channels
906 await self.connect()
907 connection = self.connection
908 kwargs = {"check_health": not self.subscribed}
909 await self._execute(connection, connection.send_command, *args, **kwargs)
911 async def connect(self):
912 """
913 Ensure that the PubSub is connected
914 """
915 if self.connection is None:
916 self.connection = await self.connection_pool.get_connection(
917 "pubsub", self.shard_hint
918 )
919 # register a callback that re-subscribes to any channels we
920 # were listening to when we were disconnected
921 self.connection.register_connect_callback(self.on_connect)
922 else:
923 await self.connection.connect()
924 if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
925 self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
927 async def _disconnect_raise_connect(self, conn, error):
928 """
929 Close the connection and raise an exception
930 if retry_on_error is not set or the error is not one
931 of the specified error types. Otherwise, try to
932 reconnect
933 """
934 await conn.disconnect()
935 if (
936 conn.retry_on_error is None
937 or isinstance(error, tuple(conn.retry_on_error)) is False
938 ):
939 raise error
940 await conn.connect()
942 async def _execute(self, conn, command, *args, **kwargs):
943 """
944 Connect manually upon disconnection. If the Redis server is down,
945 this will fail and raise a ConnectionError as desired.
946 After reconnection, the ``on_connect`` callback should have been
947 called by the # connection to resubscribe us to any channels and
948 patterns we were previously listening to
949 """
950 return await conn.retry.call_with_retry(
951 lambda: command(*args, **kwargs),
952 lambda error: self._disconnect_raise_connect(conn, error),
953 )
async def parse_response(self, block: bool = True, timeout: float = 0):
    """
    Read one response from the PubSub connection.

    A health check (``check_health``) runs first, and the connection is
    reconnected if it is not connected.

    :param block: when true, read without a timeout.
    :param timeout: read timeout, used only when ``block`` is false.
    :returns: the raw response, or ``None`` if health checks are enabled on
        the connection and the response is one of
        ``self.health_check_response``.
    :raises RuntimeError: if no connection has been set up yet.
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )

    await self.check_health()

    if not conn.is_connected:
        await conn.connect()

    response = await self._execute(
        conn,
        conn.read_response,
        timeout=(timeout if not block else None),
        disconnect_on_error=False,
        push_request=True,
    )

    # our own health-check PING replies are not something callers asked for
    is_health_reply = (
        conn.health_check_interval and response in self.health_check_response
    )
    return None if is_health_reply else response
async def check_health(self):
    """
    Send a health-check PING when the connection's check is due.

    Nothing is sent when ``health_check_interval`` is falsy, or when the
    running loop's clock has not yet passed ``next_health_check``. The PING
    carries ``self.HEALTH_CHECK_MESSAGE`` and is sent with
    ``check_health=False``.

    :raises RuntimeError: if no connection has been set up yet.
    """
    conn = self.connection
    if conn is None:
        raise RuntimeError(
            "pubsub connection not set: "
            "did you forget to call subscribe() or psubscribe()?"
        )
    if not conn.health_check_interval:
        return
    now = asyncio.get_running_loop().time()
    if now > conn.next_health_check:
        await conn.send_command(
            "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
        )
def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
    """
    Return a copy of ``data`` whose keys went through an encode/decode round trip.

    Keys are encoded and then decoded with ``self.encoder``. Channel and
    pattern names therefore end up as bytes or str, depending on whether
    responses are decoded. Incoming messages can then be matched without
    coercing each one. Values are carried over unchanged.
    """
    encode, decode = self.encoder.encode, self.encoder.decode
    normalized = {}
    for key, value in data.items():
        normalized[decode(encode(key))] = value
    return normalized  # type: ignore[return-value]
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
    """
    Subscribe to one or more channel patterns.

    Positional arguments are pattern names. Keyword arguments map a pattern
    name to a handler. Messages matching a pattern with a handler go to that
    handler instead of being produced by ``listen()``.

    Returns the result of the PSUBSCRIBE ``execute_command`` call.
    """
    if args:
        parsed_args = list_or_args((args[0],), args[1:])
    else:
        parsed_args = args
    # patterns passed positionally have no handler (None)
    subscriptions: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
    # Mypy bug: https://github.com/python/mypy/issues/10970
    subscriptions.update(kwargs)  # type: ignore[arg-type]
    ret_val = await self.execute_command("PSUBSCRIBE", *subscriptions.keys())
    # Only record the patterns after PSUBSCRIBE went out. Recording them
    # first would subscribe to them twice if a reconnect happened in between
    # (once for this command and once for the reconnection).
    normalized = self._normalize_keys(subscriptions)
    self.patterns.update(normalized)
    self.pending_unsubscribe_patterns.difference_update(normalized)
    return ret_val
def punsubscribe(self, *args: ChannelT) -> Awaitable:
    """
    Unsubscribe from the given patterns, or from all patterns if none are given.

    The affected (normalized) patterns are added to
    ``pending_unsubscribe_patterns`` before the command is issued. The
    result of ``execute_command`` is returned without being awaited.
    """
    patterns: Iterable[ChannelT]
    if not args:
        parsed_args = []
        patterns = self.patterns
    else:
        parsed_args = list_or_args((args[0],), args[1:])
        patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
    self.pending_unsubscribe_patterns.update(patterns)
    return self.execute_command("PUNSUBSCRIBE", *parsed_args)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
    """
    Subscribe to one or more channels.

    Positional arguments are channel names. Keyword arguments map a channel
    name to a handler. Messages on a channel with a handler go to that
    handler instead of being produced by ``listen()`` or ``get_message()``.

    Returns the result of the SUBSCRIBE ``execute_command`` call.
    """
    parsed_args = () if not args else list_or_args((args[0],), args[1:])
    subscriptions = dict.fromkeys(parsed_args)  # positional names -> no handler
    # Mypy bug: https://github.com/python/mypy/issues/10970
    subscriptions.update(kwargs)  # type: ignore[arg-type]
    ret_val = await self.execute_command("SUBSCRIBE", *subscriptions.keys())
    # Only record the channels after SUBSCRIBE went out. Recording them
    # first would subscribe to them twice if a reconnect happened in between
    # (once for this command and once for the reconnection).
    normalized = self._normalize_keys(subscriptions)
    self.channels.update(normalized)
    self.pending_unsubscribe_channels.difference_update(normalized)
    return ret_val
def unsubscribe(self, *args) -> Awaitable:
    """
    Unsubscribe from the given channels, or from all channels if none are given.

    The affected (normalized) channels are added to
    ``pending_unsubscribe_channels`` before the command is issued. The
    result of ``execute_command`` is returned without being awaited.
    """
    if not args:
        parsed_args = []
        channels = self.channels
    else:
        # NOTE(review): unlike subscribe()/punsubscribe(), args[0] is passed
        # to list_or_args without wrapping it in a tuple. Depending on
        # list_or_args, an iterable first argument (e.g. a list of names)
        # may be expanded here. Confirm the asymmetry is intended.
        parsed_args = list_or_args(args[0], args[1:])
        channels = self._normalize_keys(dict.fromkeys(parsed_args))
    self.pending_unsubscribe_channels.update(channels)
    return self.execute_command("UNSUBSCRIBE", *parsed_args)
async def listen(self) -> AsyncIterator:
    """
    Yield messages for as long as this client has subscriptions.

    Each response is read with a blocking ``parse_response`` and passed
    through ``handle_message``. Results that come back as ``None`` are
    skipped, for example messages consumed by a handler or filtered out.
    """
    while self.subscribed:
        raw = await self.parse_response(block=True)
        message = await self.handle_message(raw)
        if message is None:
            continue
        yield message
async def get_message(
    self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
    """
    Return the next message if one is available, otherwise ``None``.

    :param ignore_subscribe_messages: forwarded to ``handle_message``.
    :param timeout: seconds to wait for a response, as a float. ``None``
        means wait indefinitely.
    :returns: the handled message, or ``None`` if the response was empty or
        was consumed or filtered by ``handle_message``.
    """
    should_block = timeout is None
    response = await self.parse_response(block=should_block, timeout=timeout)
    if not response:
        return None
    return await self.handle_message(response, ignore_subscribe_messages)
def ping(self, message=None) -> Awaitable:
    """
    Issue PING through ``execute_command``, including ``message`` when given.

    The result of ``execute_command`` is returned without being awaited.
    """
    if message is None:
        return self.execute_command("PING")
    return self.execute_command("PING", message)
async def handle_message(self, response, ignore_subscribe_messages=False):
    """
    Turn a raw pub/sub response into a message dict and dispatch it.

    A bare ``bytes`` response is treated as a PING reply and reported as a
    ``"pong"`` message; ``b"PONG"`` itself becomes empty data.

    An ``unsubscribe``/``punsubscribe`` confirmation for a channel or pattern
    in the pending-unsubscribe set removes it from that set and from
    ``channels``/``patterns``.

    For published messages (``PUBLISH_MESSAGE_TYPES``), the handler
    registered for the channel is looked up; for ``pmessage`` the pattern's
    handler is used. If one exists, it is called with the message dict and
    ``None`` is returned. If the handler's return value is awaitable, it is
    awaited. This matches ``PubSubHandler``, which is typed as a callable
    returning an awaitable, and supports:

    * coroutine functions;
    * objects with an ``async def __call__``;
    * plain callables that return a coroutine.

    :param response: raw response (e.g. from ``parse_response``); ``None``
        yields ``None``.
    :param ignore_subscribe_messages: when true, or when
        ``self.ignore_subscribe_messages`` is set, messages that are neither
        published messages nor pongs are dropped.
    :returns: the message dict, or ``None`` if it was handled or dropped.
    """
    if response is None:
        return None
    if isinstance(response, bytes):
        response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
    message_type = str_if_bytes(response[0])
    if message_type == "pmessage":
        message = {
            "type": message_type,
            "pattern": response[1],
            "channel": response[2],
            "data": response[3],
        }
    elif message_type == "pong":
        message = {
            "type": message_type,
            "pattern": None,
            "channel": None,
            "data": response[1],
        }
    else:
        message = {
            "type": message_type,
            "pattern": None,
            "channel": response[1],
            "data": response[2],
        }

    # if this is an unsubscribe message, remove it from memory
    if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
        if message_type == "punsubscribe":
            pattern = response[1]
            if pattern in self.pending_unsubscribe_patterns:
                self.pending_unsubscribe_patterns.remove(pattern)
                self.patterns.pop(pattern, None)
        else:
            channel = response[1]
            if channel in self.pending_unsubscribe_channels:
                self.pending_unsubscribe_channels.remove(channel)
                self.channels.pop(channel, None)

    if message_type in self.PUBLISH_MESSAGE_TYPES:
        # if there's a message handler, invoke it
        if message_type == "pmessage":
            handler = self.patterns.get(message["pattern"], None)
        else:
            handler = self.channels.get(message["channel"], None)
        if handler:
            # Await whatever the handler returns when it is awaitable.
            # Checking inspect.iscoroutinefunction(handler) misses objects
            # with an async __call__ and sync callables returning coroutines.
            # Their coroutine would be created and never awaited.
            result = handler(message)
            if inspect.isawaitable(result):
                await result
            return None
    elif message_type != "pong":
        # this is a subscribe/unsubscribe message. ignore if we don't
        # want them
        if ignore_subscribe_messages or self.ignore_subscribe_messages:
            return None

    return message
async def run(
    self,
    *,
    exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
    poll_timeout: float = 1.0,
) -> None:
    """Dispatch incoming pub/sub messages to their registered handlers, forever.

    Coroutine counterpart of :py:meth:`redis.PubSub.run_in_thread`. Every
    subscribed channel and pattern must already have a handler; otherwise
    :class:`PubSubError` is raised before anything is read. Messages are
    polled with ``get_message(ignore_subscribe_messages=True,
    timeout=poll_timeout)``.

    ``asyncio.CancelledError`` always propagates. Any other exception is
    passed to ``exception_handler(exc, self)``, whose result is awaited if it
    is awaitable, and the loop continues. Without a handler, the exception
    propagates. Run it as a task and stop it by cancelling::

        >>> task = asyncio.create_task(pubsub.run())
        >>> task.cancel()
        >>> await task
    """
    # refuse to start if some subscription has nowhere to deliver messages
    for name, callback in self.channels.items():
        if callback is None:
            raise PubSubError(f"Channel: '{name}' has no handler registered")
    for name, callback in self.patterns.items():
        if callback is None:
            raise PubSubError(f"Pattern: '{name}' has no handler registered")

    await self.connect()
    while True:
        try:
            await self.get_message(
                ignore_subscribe_messages=True, timeout=poll_timeout
            )
        except asyncio.CancelledError:
            # cancellation is how callers stop this loop; never hand it off
            raise
        except BaseException as exc:
            if exception_handler is None:
                raise
            pending = exception_handler(exc, self)
            if inspect.isawaitable(pending):
                await pending
        # let other tasks run even if get_message() never had to wait on I/O
        await asyncio.sleep(0)
class PubsubWorkerExceptionHandler(Protocol):
    """Synchronous error callback for the pub/sub ``run()`` loop.

    Called as ``handler(exception, pubsub)``.
    """

    def __call__(self, e: BaseException, pubsub: PubSub): ...


class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Asynchronous error callback for the pub/sub ``run()`` loop.

    The loop awaits the value returned by the call.
    """

    async def __call__(self, e: BaseException, pubsub: PubSub): ...


# Either kind of callback is accepted as ``run(exception_handler=...)``.
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]


# One queued pipeline command: (positional command args, keyword options).
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# A pipeline's queue of commands, in the order they were issued.
CommandStackT = List[CommandT]
class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    Commands are queued locally and sent together by ``execute()``. If the
    pipeline was created as a transaction, or ``multi()`` has been called,
    the queued commands are wrapped in MULTI and EXEC, so they are executed
    atomically. Otherwise they are sent as a plain pipeline without
    MULTI/EXEC.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline; every reply is still read.

    With ``execute(raise_on_error=False)``, the exception instance is placed
    into the response list returned by ``execute()``. Code iterating over
    that list should be able to deal with an exception instance as a
    potential value. With the default ``raise_on_error=True``, the first
    such error is raised instead, annotated with the failing command.

    In general, these will be ResponseError exceptions, such as those raised
    when issuing a command on a key of a different datatype.
    """

    # Commands after which no keys are WATCHed any more. parse_response()
    # clears ``watching`` when it parses a reply to one of them.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
def __init__(
    self,
    connection_pool: ConnectionPool,
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
    transaction: bool,
    shard_hint: Optional[str],
):
    """
    Create an empty pipeline that draws connections from ``connection_pool``.

    No connection is taken here. ``immediate_execute_command`` or
    ``execute`` checks one out lazily. ``transaction`` selects MULTI/EXEC
    wrapping in ``execute``. ``super().__init__`` is not called; all state
    the pipeline needs is set up here.
    """
    # where commands go
    self.connection_pool = connection_pool
    self.shard_hint = shard_hint
    self.connection = None
    # how commands are executed and post-processed
    self.response_callbacks = response_callbacks
    self.is_transaction = transaction
    self.explicit_transaction = False
    self.watching = False
    # queued work, flushed by execute() / reset()
    self.command_stack: CommandStackT = []
    self.scripts: Set["Script"] = set()
# Presumably read by the base class's "unclosed client" warning; confirm
# against Redis.__del__.
_DEL_MESSAGE = "Unclosed Pipeline client"

def __len__(self):
    """Return how many commands are currently queued."""
    return len(self.command_stack)

def __bool__(self):
    """Always ``True``.

    Because ``__len__`` is defined, an empty queue would otherwise make
    ``bool(pipe)`` evaluate to ``False``.
    """
    return True

async def __aenter__(self: _RedisT) -> _RedisT:
    """``async with`` binds the pipeline itself."""
    return self

async def __aexit__(self, exc_type, exc_value, traceback):
    """Reset the pipeline on exit (drops queued work, releases the connection)."""
    await self.reset()

async def _async_self(self):
    """Coroutine that resolves to ``self``; backs ``__await__``."""
    return self

def __await__(self):
    """Allow ``await pipe``, which evaluates to ``pipe`` itself."""
    return self._async_self().__await__()
async def reset(self):
    """
    Drop queued commands and scripts and hand the connection back to the pool.

    If keys are being watched, UNWATCH is sent directly on the connection and
    its reply read. A ``ConnectionError`` while doing so just disconnects,
    because a dropped connection loses its WATCHes anyway. Afterwards
    ``watching`` and ``explicit_transaction`` are cleared, and the
    connection, if any, is released and forgotten.
    """
    self.command_stack = []
    self.scripts = set()
    if self.watching and self.connection:
        # Talk to the connection directly: unwatch() and
        # immediate_execute_command() can themselves end up calling reset().
        try:
            await self.connection.send_command("UNWATCH")
            await self.connection.read_response()
        except ConnectionError:
            if self.connection:
                await self.connection.disconnect()
    self.watching = False
    self.explicit_transaction = False
    # nothing is WATCHed any more, so the connection is safe to hand back
    if not self.connection:
        return
    await self.connection_pool.release(self.connection)
    self.connection = None
async def aclose(self) -> None:
    """Standard-named cleanup hook; simply awaits ``reset()``."""
    await self.reset()
def multi(self):
    """
    Start an explicit transaction block, typically after WATCH commands.

    From now on ``execute_command`` queues every command, until ``execute``
    sends them wrapped in MULTI/EXEC.

    :raises RedisError: if ``multi()`` was already called, or if commands
        have already been queued.
    """
    if self.explicit_transaction:
        raise RedisError("Cannot issue nested calls to MULTI")
    elif self.command_stack:
        raise RedisError(
            "Commands without an initial WATCH have already been issued"
        )
    self.explicit_transaction = True
def execute_command(
    self, *args, **kwargs
) -> Union["Pipeline", Awaitable["Pipeline"]]:
    """
    Either run a command immediately or queue it on the pipeline.

    Outside an explicit ``multi()`` block, a WATCH, or any command issued
    while keys are watched, runs immediately and an awaitable is returned.
    Everything else is queued and the pipeline is returned for chaining.
    """
    kwargs.pop("keys", None)  # only used for client side caching
    run_now = (self.watching or args[0] == "WATCH") and not self.explicit_transaction
    if run_now:
        return self.immediate_execute_command(*args, **kwargs)
    return self.pipeline_execute_command(*args, **kwargs)
1345 async def _disconnect_reset_raise(self, conn, error):
1346 """
1347 Close the connection, reset watching state and
1348 raise an exception if we were watching,
1349 if retry_on_error is not set or the error is not one
1350 of the specified error types.
1351 """
1352 await conn.disconnect()
1353 # if we were already watching a variable, the watch is no longer
1354 # valid since this connection has died. raise a WatchError, which
1355 # indicates the user should retry this transaction.
1356 if self.watching:
1357 await self.aclose()
1358 raise WatchError(
1359 "A ConnectionError occurred on while watching one or more keys"
1360 )
1361 # if retry_on_error is not set or the error is not one
1362 # of the specified error types, raise it
1363 if (
1364 conn.retry_on_error is None
1365 or isinstance(error, tuple(conn.retry_on_error)) is False
1366 ):
1367 await self.aclose()
1368 raise
async def immediate_execute_command(self, *args, **options):
    """
    Send a command right away instead of queueing it.

    Used for WATCH, and for commands issued after WATCH but before
    ``multi()``. A connection is checked out of the pool on first use and
    kept on ``self.connection``. Failures go through
    ``_disconnect_reset_raise``, which raises ``WatchError`` rather than
    retrying if keys are being watched.
    """
    command_name = args[0]
    if not self.connection:
        # first call: we need a connection
        self.connection = await self.connection_pool.get_connection(
            command_name, self.shard_hint
        )
    conn = self.connection

    def attempt():
        return self._send_command_parse_response(
            conn, command_name, *args, **options
        )

    def on_failure(error):
        return self._disconnect_reset_raise(conn, error)

    return await conn.retry.call_with_retry(attempt, on_failure)
def pipeline_execute_command(self, *args, **options):
    """
    Queue a command for the next ``execute()`` and return the pipeline.

    Because the pipeline itself is returned, calls can be chained::

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

    Nothing reaches the server until ``execute()`` runs.
    """
    queued = (args, options)
    self.command_stack.append(queued)
    return self
async def _execute_transaction(  # noqa: C901
    self, connection: Connection, commands: CommandStackT, raise_on_error
):
    """
    Send ``commands`` wrapped in MULTI/EXEC and return their processed results.

    All commands are packed and written in one go. Commands whose options
    contain ``EMPTY_RESPONSE`` are not sent; their result is that option's
    value. The replies read before EXEC (parsed with command name ``"_"``)
    are discarded, except that any ``ResponseError`` is recorded. The actual
    results come from the EXEC reply.

    :param connection: connection the transaction is written to.
    :param commands: queued ``(args, options)`` pairs.
    :param raise_on_error: if true, raise the first ``ResponseError`` found
        in the results.
    :returns: one result per command, with response callbacks applied.
        Exceptions are left in place when not raised.
    :raises ExecAbortError: if EXEC was aborted. If a recorded
        ``ResponseError`` exists, that error is raised instead, chained to
        the abort.
    :raises WatchError: if EXEC returned ``None`` (a watched key changed).
    :raises ResponseError: if the number of results does not match the
        number of commands; the pipeline's connection is disconnected first.
    """
    pre: CommandT = (("MULTI",), {})
    post: CommandT = (("EXEC",), {})
    cmds = (pre, *commands, post)
    all_cmds = connection.pack_commands(
        args for args, options in cmds if EMPTY_RESPONSE not in options
    )
    await connection.send_packed_command(all_cmds)
    # (index, value) pairs later spliced into the EXEC reply: ResponseErrors
    # plus EMPTY_RESPONSE placeholders for commands that were never sent
    errors = []

    # parse off the response for MULTI
    # NOTE: we need to handle ResponseErrors here and continue
    # so that we read all the additional command messages from
    # the socket
    try:
        await self.parse_response(connection, "_")
    except ResponseError as err:
        errors.append((0, err))

    # and all the other commands
    for i, command in enumerate(commands):
        if EMPTY_RESPONSE in command[1]:
            errors.append((i, command[1][EMPTY_RESPONSE]))
        else:
            try:
                await self.parse_response(connection, "_")
            except ResponseError as err:
                self.annotate_exception(err, i + 1, command[0])
                errors.append((i, err))

    # parse the EXEC.
    try:
        response = await self.parse_response(connection, "_")
    except ExecAbortError as err:
        # ``errors`` also holds EMPTY_RESPONSE placeholders, which are plain
        # values. Re-raise the first entry that really is an exception;
        # raising a placeholder would fail with TypeError.
        first_error = next(
            (value for _, value in errors if isinstance(value, BaseException)),
            None,
        )
        if first_error is not None:
            raise first_error from err
        raise

    # EXEC clears any watched keys
    self.watching = False

    if response is None:
        raise WatchError("Watched variable changed.") from None

    # put any parse errors into the response
    for i, e in errors:
        response.insert(i, e)

    if len(response) != len(commands):
        if self.connection:
            await self.connection.disconnect()
        raise ResponseError(
            "Wrong number of response items from pipeline execution"
        ) from None

    # find any errors in the response and raise if necessary
    if raise_on_error:
        self.raise_first_error(commands, response)

    # We have to run response callbacks manually
    data = []
    for r, cmd in zip(response, commands):
        if not isinstance(r, Exception):
            args, options = cmd
            command_name = args[0]
            if command_name in self.response_callbacks:
                r = self.response_callbacks[command_name](r, **options)
                # callbacks may be plain functions or return awaitables
                if inspect.isawaitable(r):
                    r = await r
        data.append(r)
    return data
async def _execute_pipeline(
    self, connection: Connection, commands: CommandStackT, raise_on_error: bool
):
    """
    Send ``commands`` as a plain (non-transactional) pipeline.

    All commands are packed into a single write. One reply per command is
    then read through ``parse_response`` with that command's name and
    options. A ``ResponseError`` is stored in place of its reply, so the
    remaining replies are still read.

    :returns: replies (or stored errors) in command order.
    :raises ResponseError: the first stored error, if ``raise_on_error``.
    """
    # one write for the whole batch keeps round trips down
    packed = connection.pack_commands([cmd_args for cmd_args, _ in commands])
    await connection.send_packed_command(packed)

    results = []
    for cmd_args, cmd_options in commands:
        try:
            reply = await self.parse_response(
                connection, cmd_args[0], **cmd_options
            )
        except ResponseError as exc:
            reply = exc
        results.append(reply)

    if raise_on_error:
        self.raise_first_error(commands, results)
    return results
def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
    """
    Raise the first ``ResponseError`` in ``response``, if there is one.

    Before raising, the error is annotated with its 1-based position and
    its command's arguments (``commands[i][0]``). Returns ``None`` when
    ``response`` contains no ``ResponseError``.
    """
    for position, item in enumerate(response, start=1):
        if not isinstance(item, ResponseError):
            continue
        self.annotate_exception(item, position, commands[position - 1][0])
        raise item
def annotate_exception(
    self, exception: Exception, number: int, command: Iterable[object]
) -> None:
    """
    Prefix ``exception``'s message, in place, with the failing pipeline command.

    ``exception.args[0]`` becomes
    ``"Command # <number> (<command>) of pipeline caused error: <message>"``,
    where ``<message>`` is the original first argument. Any further args are
    kept as they are.

    :param exception: error to annotate.
    :param number: 1-based position of the command in the pipeline.
    :param command: the command's arguments, joined with spaces for display.
    """
    cmd = " ".join(map(safe_str, command))
    # Interpolate only the original message. Formatting the whole ``args``
    # tuple renders as "('message',)" and repeats extra args that are
    # carried over below anyway. Guard against an exception with no args.
    original = exception.args[0] if exception.args else ""
    msg = f"Command # {number} ({cmd}) of pipeline caused error: {original}"
    exception.args = (msg,) + exception.args[1:]
async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """
    Parse a reply via the base class and keep ``watching`` in sync.

    After a successful parse, a command in ``UNWATCH_COMMANDS``
    (DISCARD/EXEC/UNWATCH) clears ``self.watching`` and WATCH sets it.
    Other commands leave it untouched.
    """
    result = await super().parse_response(connection, command_name, **options)
    if command_name in self.UNWATCH_COMMANDS:
        now_watching = False
    elif command_name == "WATCH":
        now_watching = True
    else:
        return result
    self.watching = now_watching
    return result
async def load_scripts(self):
    """
    Make sure every script registered on this pipeline exists on the server.

    SCRIPT EXISTS is run immediately for all script SHAs. Each missing
    script is then sent with SCRIPT LOAD, and the returned SHA is stored
    back on its script object.
    """
    # the regular script_* helpers would only be queued on the pipeline,
    # so go through immediate_execute_command instead
    scripts = list(self.scripts)
    run_now = self.immediate_execute_command
    present = await run_now("SCRIPT EXISTS", *(script.sha for script in scripts))
    if all(present):
        return
    for script, exists in zip(scripts, present):
        if not exists:
            script.sha = await run_now("SCRIPT LOAD", script.script)
async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
    """
    Failure hook used by ``execute`` while retrying.

    ``conn`` is always disconnected first. If keys were being watched,
    ``WatchError`` is raised (chained to ``error``): the WATCH is gone with
    the connection, so the caller should retry the transaction.

    Otherwise, if ``conn.retry_on_error`` is ``None`` or ``error`` is not an
    instance of one of its types, the pipeline is reset and ``error`` is
    re-raised. For retryable errors the hook returns normally and the retry
    policy continues.

    :raises WatchError: if keys were being watched.
    :raises Exception: ``error`` itself, when it is not retryable.
    """
    await conn.disconnect()
    # if we were watching a variable, the watch is no longer valid
    # since this connection has died. raise a WatchError, which
    # indicates the user should retry this transaction.
    if self.watching:
        raise WatchError(
            "A ConnectionError occurred on while watching one or more keys"
        ) from error
    # if retry_on_error is not set or the error is not one
    # of the specified error types, raise it
    if conn.retry_on_error is None or not isinstance(
        error, tuple(conn.retry_on_error)
    ):
        await self.reset()
        # Raise the error we were handed explicitly. A bare ``raise`` only
        # works while an exception is being handled further up the stack.
        raise error
async def execute(self, raise_on_error: bool = True):
    """
    Run every queued command and return the list of results.

    Registered scripts are loaded first. The commands are sent as a
    MULTI/EXEC transaction when the pipeline is transactional or
    ``multi()`` was called, and as a plain pipeline otherwise. Failures go
    through ``conn.retry`` with ``_disconnect_raise_reset`` as the failure
    hook. Whatever happens, the pipeline is reset afterwards, which releases
    its connection back to the pool.

    :param raise_on_error: raise the first ``ResponseError`` instead of
        returning it among the results.
    :returns: the results; ``[]`` right away if nothing is queued and no
        keys are watched.
    """
    stack = self.command_stack
    if not (stack or self.watching):
        return []
    if self.scripts:
        await self.load_scripts()
    if self.is_transaction or self.explicit_transaction:
        runner = self._execute_transaction
    else:
        runner = self._execute_pipeline

    if not self.connection:
        # stored on self so that reset() hands it back to the pool later
        self.connection = await self.connection_pool.get_connection(
            "MULTI", self.shard_hint
        )
    conn = cast(Connection, self.connection)

    def attempt():
        return runner(conn, stack, raise_on_error)

    def on_failure(error):
        return self._disconnect_raise_reset(conn, error)

    try:
        return await conn.retry.call_with_retry(attempt, on_failure)
    finally:
        await self.reset()
async def discard(self):
    """Issue DISCARD through ``execute_command`` and await the outcome.

    Whether it is sent right away or queued follows ``execute_command``.
    NOTE(review): the local ``command_stack`` is not cleared here, and
    inside ``multi()`` the DISCARD is itself just queued. Confirm this
    matches the intended "flush queued commands" semantics.

    See: https://redis.io/commands/DISCARD
    """
    pending = self.execute_command("DISCARD")
    await pending
async def watch(self, *names: KeyT):
    """
    WATCH the keys ``names``.

    If a watched key changes before the transaction runs, EXEC returns
    nothing and ``execute`` raises ``WatchError``.

    :returns: the result of the WATCH command.
    :raises RedisError: if called after ``multi()``.
    """
    if not self.explicit_transaction:
        return await self.execute_command("WATCH", *names)
    raise RedisError("Cannot issue a WATCH after a MULTI")
async def unwatch(self):
    """
    Send UNWATCH if keys are currently being watched.

    :returns: ``True`` when nothing was watched. Otherwise the UNWATCH
        result, or ``True`` if that result is falsy.
    """
    if not self.watching:
        return True
    result = await self.execute_command("UNWATCH")
    return result or True