import asyncio
import copy
import inspect
import re
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Protocol,
    Set,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

from redis._parsers.helpers import (
    _RedisCallbacks,
    _RedisCallbacksRESP2,
    _RedisCallbacksRESP3,
    bool_ok,
)
from redis.asyncio.connection import (
    Connection,
    ConnectionPool,
    SSLConnection,
    UnixDomainSocketConnection,
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
    EMPTY_RESPONSE,
    NEVER_DECODE,
    AbstractRedis,
    CaseInsensitiveDict,
)
from redis.commands import (
    AsyncCoreCommands,
    AsyncRedisModuleCommands,
    AsyncSentinelCommands,
    list_or_args,
)
from redis.credentials import CredentialProvider
from redis.event import (
    AfterPooledConnectionsInstantiationEvent,
    AfterPubSubConnectionInstantiationEvent,
    AfterSingleConnectionInstantiationEvent,
    ClientType,
    EventDispatcher,
)
from redis.exceptions import (
    ConnectionError,
    ExecAbortError,
    PubSubError,
    RedisError,
    ResponseError,
    WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
    SSL_AVAILABLE,
    _set_info_logger,
    deprecated_args,
    deprecated_function,
    get_lib_version,
    safe_str,
    str_if_bytes,
    truncate_text,
)

if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    from redis.commands.core import Script


class ResponseCallbackProtocol(Protocol):
    def __call__(self, response: Any, **kwargs): ...


class AsyncResponseCallbackProtocol(Protocol):
    async def __call__(self, response: Any, **kwargs): ...


ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]


class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
112 """
113 Implementation of the Redis protocol.
114
115 This abstract class provides a Python interface to all Redis commands
116 and an implementation of the Redis protocol.
117
118 Pipelines derive from this, implementing how
119 the commands are sent and received to the Redis server. Based on
120 configuration, an instance will either use a ConnectionPool, or
121 Connection object to talk to redis.
122 """

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
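
        A minimal usage sketch (assumes a reachable Redis server at the
        given URL)::

            pool = ConnectionPool.from_url("redis://localhost:6379/0")
            client = Redis.from_pool(pool)
            try:
                await client.ping()
            finally:
                await client.aclose()  # also closes the owned pool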
204 """
205 client = cls(
206 connection_pool=connection_pool,
207 )
208 client.auto_close_connection_pool = True
209 return client
210
211 @deprecated_args(
212 args_to_warn=["retry_on_timeout"],
213 reason="TimeoutError is included by default.",
214 version="6.0.0",
215 )
216 def __init__(
217 self,
218 *,
219 host: str = "localhost",
220 port: int = 6379,
221 db: Union[str, int] = 0,
222 password: Optional[str] = None,
223 socket_timeout: Optional[float] = None,
224 socket_connect_timeout: Optional[float] = None,
225 socket_keepalive: Optional[bool] = None,
226 socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
227 connection_pool: Optional[ConnectionPool] = None,
228 unix_socket_path: Optional[str] = None,
229 encoding: str = "utf-8",
230 encoding_errors: str = "strict",
231 decode_responses: bool = False,
232 retry_on_timeout: bool = False,
233 retry: Retry = Retry(
234 backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
235 ),
236 retry_on_error: Optional[list] = None,
237 ssl: bool = False,
238 ssl_keyfile: Optional[str] = None,
239 ssl_certfile: Optional[str] = None,
240 ssl_cert_reqs: Union[str, VerifyMode] = "required",
241 ssl_ca_certs: Optional[str] = None,
242 ssl_ca_data: Optional[str] = None,
243 ssl_check_hostname: bool = True,
244 ssl_min_version: Optional[TLSVersion] = None,
245 ssl_ciphers: Optional[str] = None,
246 max_connections: Optional[int] = None,
247 single_connection_client: bool = False,
248 health_check_interval: int = 0,
249 client_name: Optional[str] = None,
250 lib_name: Optional[str] = "redis-py",
251 lib_version: Optional[str] = get_lib_version(),
252 username: Optional[str] = None,
253 auto_close_connection_pool: Optional[bool] = None,
254 redis_connect_func=None,
255 credential_provider: Optional[CredentialProvider] = None,
256 protocol: Optional[int] = 2,
257 event_dispatcher: Optional[EventDispatcher] = None,
258 ):
259 """
260 Initialize a new Redis client.
261
262 To specify a retry policy for specific errors, you have two options:
263
264 1. Set the `retry_on_error` to a list of the error/s to retry on, and
265 you can also set `retry` to a valid `Retry` object(in case the default
266 one is not appropriate) - with this approach the retries will be triggered
267 on the default errors specified in the Retry object enriched with the
268 errors specified in `retry_on_error`.
269
270 2. Define a `Retry` object with configured 'supported_errors' and set
271 it to the `retry` parameter - with this approach you completely redefine
272 the errors on which retries will happen.
273
274 `retry_on_timeout` is deprecated - please include the TimeoutError
275 either in the Retry object or in the `retry_on_error` list.
276
277 When 'connection_pool' is provided - the retry configuration of the
278 provided pool will be used.
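
        For example, a minimal sketch of option 1 (retrying on
        ``BusyLoadingError`` in addition to the defaults is just an
        illustrative assumption)::

            from redis.asyncio import Redis
            from redis.asyncio.retry import Retry
            from redis.backoff import ExponentialWithJitterBackoff
            from redis.exceptions import BusyLoadingError

            r = Redis(
                retry=Retry(
                    backoff=ExponentialWithJitterBackoff(base=1, cap=10),
                    retries=5,
                ),
                retry_on_error=[BusyLoadingError],
            )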
279 """
280 kwargs: Dict[str, Any]
281 if event_dispatcher is None:
282 self._event_dispatcher = EventDispatcher()
283 else:
284 self._event_dispatcher = event_dispatcher
285 # auto_close_connection_pool only has an effect if connection_pool is
286 # None. It is assumed that if connection_pool is not None, the user
287 # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, set up the appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                        }
                    )
            # This arg is only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's keyword arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates commands named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
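
        A minimal usage sketch (assumes an initialized client ``r``)::

            pipe = r.pipeline()
            pipe.set("foo", "bar").incr("counter")
            set_ok, counter = await pipe.execute()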
463 """
464 return Pipeline(
465 self.connection_pool, self.response_callbacks, transaction, shard_hint
466 )
467
468 async def transaction(
469 self,
470 func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
471 *watches: KeyT,
472 shard_hint: Optional[str] = None,
473 value_from_callable: bool = False,
474 watch_delay: Optional[float] = None,
475 ):
476 """
477 Convenience method for executing the callable `func` as a transaction
478 while watching all keys specified in `watches`. The 'func' callable
479 should expect a single argument which is a Pipeline object.
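
        A minimal sketch (assumes an initialized client ``r`` and an
        existing integer key ``balance``)::

            async def double_balance(pipe: "Pipeline"):
                value = int(await pipe.get("balance"))
                pipe.multi()
                pipe.set("balance", value * 2)

            await r.transaction(double_balance, "balance")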
480 """
481 pipe: Pipeline
482 async with self.pipeline(True, shard_hint) as pipe:
483 while True:
484 try:
485 if watches:
486 await pipe.watch(*watches)
487 func_value = func(pipe)
488 if inspect.isawaitable(func_value):
489 func_value = await func_value
490 exec_value = await pipe.execute()
491 return func_value if value_from_callable else exec_value
492 except WatchError:
493 if watch_delay is not None and watch_delay > 0:
494 await asyncio.sleep(watch_delay)
495 continue
496
497 def lock(
498 self,
499 name: KeyT,
500 timeout: Optional[float] = None,
501 sleep: float = 0.1,
502 blocking: bool = True,
503 blocking_timeout: Optional[float] = None,
504 lock_class: Optional[Type[Lock]] = None,
505 thread_local: bool = True,
506 raise_on_release_error: bool = True,
507 ) -> Lock:
508 """
509 Return a new Lock object using key ``name`` that mimics
510 the behavior of threading.Lock.
511
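        A minimal sketch (assumes an initialized client ``r``; ``Lock``
        supports use as an async context manager)::

            async with r.lock("resource-lock", timeout=10):
                ...  # critical section
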
        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        return await self.initialize()

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove the keys entry; it is only needed for the cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response


StrictRedis = Redis


class MonitorCommandInfo(TypedDict):
    time: float
    db: int
    client_address: str
    client_port: str
    client_type: str
    command: str


class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    The next_command() method returns one command from the monitor stream;
    listen() yields commands from the monitor.
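
    A minimal usage sketch (assumes an initialized client ``r``)::

        async with r.monitor() as m:
            async for info in m.listen():
                print(info["command"])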
735 """
736
737 monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
738 command_re = re.compile(r'"(.*?)(?<!\\)"')
739
740 def __init__(self, connection_pool: ConnectionPool):
741 self.connection_pool = connection_pool
742 self.connection: Optional[Connection] = None
743
744 async def connect(self):
745 if self.connection is None:
746 self.connection = await self.connection_pool.get_connection()
747
748 async def __aenter__(self):
749 await self.connect()
750 await self.connection.send_command("MONITOR")
751 # check that monitor returns 'OK', but don't return it to user
752 response = await self.connection.read_response()
753 if not bool_ok(response):
754 raise RedisError(f"MONITOR failed: {response}")
755 return self
756
757 async def __aexit__(self, *args):
758 await self.connection.disconnect()
759 await self.connection_pool.release(self.connection)
760
761 async def next_command(self) -> MonitorCommandInfo:
762 """Parse the response from a monitor command"""
763 await self.connect()
764 response = await self.connection.read_response()
765 if isinstance(response, bytes):
766 response = self.connection.encoder.decode(response, force=True)
767 command_time, command_data = response.split(" ", 1)
768 m = self.monitor_re.match(command_data)
769 db_id, client_info, command = m.groups()
770 command = " ".join(self.command_re.findall(command))
771 # Redis escapes double quotes because each piece of the command
772 # string is surrounded by double quotes. We don't have that
773 # requirement so remove the escaping and leave the quote.
774 command = command.replace('\\"', '"')
775
776 if client_info == "lua":
777 client_address = "lua"
778 client_port = ""
779 client_type = "lua"
780 elif client_info.startswith("unix"):
781 client_address = "unix"
782 client_port = client_info[5:]
783 client_type = "unix"
784 else:
785 # use rsplit as ipv6 addresses contain colons
786 client_address, client_port = client_info.rsplit(":", 1)
787 client_type = "tcp"
788 return {
789 "time": float(command_time),
790 "db": int(db_id),
791 "client_address": client_address,
792 "client_port": client_port,
793 "client_type": client_type,
794 "command": command,
795 }
796
797 async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
798 """Listen for commands coming to the server."""
799 while True:
800 yield await self.next_command()
801
802
class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.
    """

    PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to lookup channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    def __del__(self):
        if self.connection:
            self.connection.deregister_connect_callback(self.on_connect)

    async def aclose(self):
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    async def on_connect(self, connection: Connection):
        """Re-subscribe to any channels and patterns previously subscribed to"""
        # NOTE: for python3, we can't pass bytestrings as keyword arguments
        # so we need to decode channel/pattern names back to unicode strings
        # before passing them to [p]subscribe.
        self.pending_unsubscribe_channels.clear()
        self.pending_unsubscribe_patterns.clear()
        if self.channels:
            channels = {}
            for k, v in self.channels.items():
                channels[self.encoder.decode(k, force=True)] = v
            await self.subscribe(**channels)
        if self.patterns:
            patterns = {}
            for k, v in self.patterns.items():
                patterns[self.encoder.decode(k, force=True)] = v
            await self.psubscribe(**patterns)

    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        return bool(self.channels or self.patterns)

    async def execute_command(self, *args: EncodableT):
        """Execute a publish/subscribe command"""

        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        await self.connect()
        connection = self.connection
        kwargs = {"check_health": not self.subscribed}
        await self._execute(connection, connection.send_command, *args, **kwargs)

    async def connect(self):
        """
        Ensure that the PubSub is connected
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )

    async def _reconnect(self, conn):
        """
        Try to reconnect
        """
        await conn.disconnect()
        await conn.connect()

    async def _execute(self, conn, command, *args, **kwargs):
        """
        Connect manually upon disconnection. If the Redis server is down,
        this will fail and raise a ConnectionError as desired.
        After reconnection, the ``on_connect`` callback should have been
        called by the connection to resubscribe us to any channels and
        patterns we were previously listening to.
        """
        return await conn.retry.call_with_retry(
            lambda: command(*args, **kwargs),
            lambda _: self._reconnect(conn),
        )

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """Parse the response from a publish/subscribe command"""
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as user might not expect it
            return None
        return response

    async def check_health(self):
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        if (
            conn.health_check_interval
            and asyncio.get_running_loop().time() > conn.next_health_check
        ):
            await conn.send_command(
                "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
            )

    def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
        """
        normalize channel/pattern names to be either bytes or strings
        based on whether responses are automatically decoded. this saves us
        from coercing the value for each message coming in.
        """
        encode = self.encoder.encode
        decode = self.encoder.decode
        return {decode(encode(k)): v for k, v in data.items()}  # type: ignore[return-value] # noqa: E501

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
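
        A minimal sketch (assumes a ``PubSub`` instance ``p``; the pattern
        names are illustrative)::

            async def on_news(message):
                print(message["channel"], message["data"])

            await p.psubscribe("news.*", **{"alerts.*": on_news})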
1030 """
1031 parsed_args = list_or_args((args[0],), args[1:]) if args else args
1032 new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
1033 # Mypy bug: https://github.com/python/mypy/issues/10970
1034 new_patterns.update(kwargs) # type: ignore[arg-type]
1035 ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
1036 # update the patterns dict AFTER we send the command. we don't want to
1037 # subscribe twice to these patterns, once for the command and again
1038 # for the reconnection.
1039 new_patterns = self._normalize_keys(new_patterns)
1040 self.patterns.update(new_patterns)
1041 self.pending_unsubscribe_patterns.difference_update(new_patterns)
1042 return ret_val
1043
1044 def punsubscribe(self, *args: ChannelT) -> Awaitable:
1045 """
1046 Unsubscribe from the supplied patterns. If empty, unsubscribe from
1047 all patterns.
1048 """
1049 patterns: Iterable[ChannelT]
1050 if args:
1051 parsed_args = list_or_args((args[0],), args[1:])
1052 patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
1053 else:
1054 parsed_args = []
1055 patterns = self.patterns
1056 self.pending_unsubscribe_patterns.update(patterns)
1057 return self.execute_command("PUNSUBSCRIBE", *parsed_args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
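
        A minimal sketch (assumes an initialized client ``r``; channel
        names are illustrative)::

            async def my_handler(message):
                print(message["data"])

            p = r.pubsub()
            await p.subscribe("news", alerts=my_handler)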
1066 """
1067 parsed_args = list_or_args((args[0],), args[1:]) if args else ()
1068 new_channels = dict.fromkeys(parsed_args)
1069 # Mypy bug: https://github.com/python/mypy/issues/10970
1070 new_channels.update(kwargs) # type: ignore[arg-type]
1071 ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
1072 # update the channels dict AFTER we send the command. we don't want to
1073 # subscribe twice to these channels, once for the command and again
1074 # for the reconnection.
1075 new_channels = self._normalize_keys(new_channels)
1076 self.channels.update(new_channels)
1077 self.pending_unsubscribe_channels.difference_update(new_channels)
1078 return ret_val
1079
1080 def unsubscribe(self, *args) -> Awaitable:
1081 """
1082 Unsubscribe from the supplied channels. If empty, unsubscribe from
1083 all channels
1084 """
1085 if args:
1086 parsed_args = list_or_args(args[0], args[1:])
1087 channels = self._normalize_keys(dict.fromkeys(parsed_args))
1088 else:
1089 parsed_args = []
1090 channels = self.channels
1091 self.pending_unsubscribe_channels.update(channels)
1092 return self.execute_command("UNSUBSCRIBE", *parsed_args)
1093
1094 async def listen(self) -> AsyncIterator:
1095 """Listen for messages on channels this client has been subscribed to"""
1096 while self.subscribed:
1097 response = await self.handle_message(await self.parse_response(block=True))
1098 if response is not None:
1099 yield response

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
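
        A minimal polling sketch (assumes a ``PubSub`` instance ``p``
        subscribed to at least one channel)::

            while True:
                message = await p.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if message is not None:
                    print(message["channel"], message["data"])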
1110 """
1111 response = await self.parse_response(block=(timeout is None), timeout=timeout)
1112 if response:
1113 return await self.handle_message(response, ignore_subscribe_messages)
1114 return None
1115
1116 def ping(self, message=None) -> Awaitable:
1117 """
1118 Ping the Redis server
1119 """
1120 args = ["PING", message] if message is not None else ["PING"]
1121 return self.execute_command(*args)
1122
1123 async def handle_message(self, response, ignore_subscribe_messages=False):
1124 """
1125 Parses a pub/sub message. If the channel or pattern was subscribed to
1126 with a message handler, the handler is invoked instead of a parsed
1127 message being returned.
1128 """
1129 if response is None:
1130 return None
1131 if isinstance(response, bytes):
1132 response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
1133 message_type = str_if_bytes(response[0])
1134 if message_type == "pmessage":
1135 message = {
1136 "type": message_type,
1137 "pattern": response[1],
1138 "channel": response[2],
1139 "data": response[3],
1140 }
1141 elif message_type == "pong":
1142 message = {
1143 "type": message_type,
1144 "pattern": None,
1145 "channel": None,
1146 "data": response[1],
1147 }
1148 else:
1149 message = {
1150 "type": message_type,
1151 "pattern": None,
1152 "channel": response[1],
1153 "data": response[2],
1154 }
1155
1156 # if this is an unsubscribe message, remove it from memory
1157 if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
1158 if message_type == "punsubscribe":
1159 pattern = response[1]
1160 if pattern in self.pending_unsubscribe_patterns:
1161 self.pending_unsubscribe_patterns.remove(pattern)
1162 self.patterns.pop(pattern, None)
1163 else:
1164 channel = response[1]
1165 if channel in self.pending_unsubscribe_channels:
1166 self.pending_unsubscribe_channels.remove(channel)
1167 self.channels.pop(channel, None)
1168
1169 if message_type in self.PUBLISH_MESSAGE_TYPES:
1170 # if there's a message handler, invoke it
1171 if message_type == "pmessage":
1172 handler = self.patterns.get(message["pattern"], None)
1173 else:
1174 handler = self.channels.get(message["channel"], None)
1175 if handler:
1176 if inspect.iscoroutinefunction(handler):
1177 await handler(message)
1178 else:
1179 handler(message)
1180 return None
1181 elif message_type != "pong":
1182 # this is a subscribe/unsubscribe message. ignore if we don't
1183 # want them
1184 if ignore_subscribe_messages or self.ignore_subscribe_messages:
1185 return None
1186
1187 return message

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

            >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

            >>> task.cancel()
            >>> await task
        """
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                await self.get_message(
                    ignore_subscribe_messages=True, timeout=poll_timeout
                )
            except asyncio.CancelledError:
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)


class PubsubWorkerExceptionHandler(Protocol):
    def __call__(self, e: BaseException, pubsub: PubSub): ...


class AsyncPubsubWorkerExceptionHandler(Protocol):
    async def __call__(self, e: BaseException, pubsub: PubSub): ...


PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]


CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
CommandStackT = List[CommandT]


class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline (when running in transactional
    mode, which is the default behavior) are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

    def __init__(
        self,
        connection_pool: ConnectionPool,
        response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
        transaction: bool,
        shard_hint: Optional[str],
    ):
        self.connection_pool = connection_pool
        self.connection = None
        self.response_callbacks = response_callbacks
        self.is_transaction = transaction
        self.shard_hint = shard_hint
        self.watching = False
        self.command_stack: CommandStackT = []
        self.scripts: Set[Script] = set()
        self.explicit_transaction = False

    async def __aenter__(self: _RedisT) -> _RedisT:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.reset()

    def __await__(self):
        return self._async_self().__await__()

    _DEL_MESSAGE = "Unclosed Pipeline client"

    def __len__(self):
        return len(self.command_stack)

    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        return True

    async def _async_self(self):
        return self

    async def reset(self):
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None

    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
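
        A minimal sketch (assumes a pipeline ``pipe`` obtained from
        ``Redis.pipeline()`` and an existing key ``key``)::

            await pipe.watch("key")
            value = await pipe.get("key")
            pipe.multi()
            pipe.set("key", value)
            await pipe.execute()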
1343 """
1344 if self.explicit_transaction:
1345 raise RedisError("Cannot issue nested calls to MULTI")
1346 if self.command_stack:
1347 raise RedisError(
1348 "Commands without an initial WATCH have already been issued"
1349 )
1350 self.explicit_transaction = True
1351
1352 def execute_command(
1353 self, *args, **kwargs
1354 ) -> Union["Pipeline", Awaitable["Pipeline"]]:
1355 if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
1356 return self.immediate_execute_command(*args, **kwargs)
1357 return self.pipeline_execute_command(*args, **kwargs)
1358
    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
    ):
        """
        Close the connection, reset the watching state, and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors if we're already WATCHing a variable. Used when issuing WATCH
        or subsequent commands that retrieve their values, but before MULTI
        is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        return await conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise_on_watching(conn, error),
        )

    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self

    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove the keys entry; it is only needed for the cache.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data

    async def _execute_pipeline(
        self, connection: Connection, commands: CommandStackT, raise_on_error: bool
    ):
        # build up all commands into a single request to increase network perf
        all_cmds = connection.pack_commands([args for args, _ in commands])
        await connection.send_packed_command(all_cmds)

        response = []
        for args, options in commands:
            try:
                response.append(
                    await self.parse_response(connection, args[0], **options)
                )
            except ResponseError as e:
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)
        return response

    def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
        for i, r in enumerate(response):
            if isinstance(r, ResponseError):
                self.annotate_exception(r, i + 1, commands[i][0])
                raise r

    def annotate_exception(
        self, exception: Exception, number: int, command: Iterable[object]
    ) -> None:
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) "
            f"of pipeline caused error: {exception.args}"
        )
        exception.args = (msg,) + exception.args[1:]

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        result = await super().parse_response(connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

    async def load_scripts(self):
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = await immediate("SCRIPT EXISTS", *shas)
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await immediate("SCRIPT LOAD", s.script)

    async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
        """
        Close the connection and raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (at the
        connection level).
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def execute(self, raise_on_error: bool = True) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        stack = self.command_stack
        if not stack and not self.watching:
            return []
        if self.scripts:
            await self.load_scripts()
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        try:
            return await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_on_watching(conn, error),
            )
        finally:
            await self.reset()

    async def discard(self):
        """Flushes all previously queued commands.
        See: https://redis.io/commands/DISCARD
        """
        await self.execute_command("DISCARD")

    async def watch(self, *names: KeyT):
        """Watches the values at keys ``names``"""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        return self.watching and await self.execute_command("UNWATCH") or True