1import asyncio
2import copy
3import inspect
4import re
5import warnings
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 AsyncIterator,
10 Awaitable,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 MutableMapping,
17 Optional,
18 Protocol,
19 Set,
20 Tuple,
21 Type,
22 TypedDict,
23 TypeVar,
24 Union,
25 cast,
26)
27
28from redis._parsers.helpers import (
29 _RedisCallbacks,
30 _RedisCallbacksRESP2,
31 _RedisCallbacksRESP3,
32 bool_ok,
33)
34from redis.asyncio.connection import (
35 Connection,
36 ConnectionPool,
37 SSLConnection,
38 UnixDomainSocketConnection,
39)
40from redis.asyncio.lock import Lock
41from redis.asyncio.retry import Retry
42from redis.backoff import ExponentialWithJitterBackoff
43from redis.client import (
44 EMPTY_RESPONSE,
45 NEVER_DECODE,
46 AbstractRedis,
47 CaseInsensitiveDict,
48)
49from redis.commands import (
50 AsyncCoreCommands,
51 AsyncRedisModuleCommands,
52 AsyncSentinelCommands,
53 list_or_args,
54)
55from redis.credentials import CredentialProvider
56from redis.event import (
57 AfterPooledConnectionsInstantiationEvent,
58 AfterPubSubConnectionInstantiationEvent,
59 AfterSingleConnectionInstantiationEvent,
60 ClientType,
61 EventDispatcher,
62)
63from redis.exceptions import (
64 ConnectionError,
65 ExecAbortError,
66 PubSubError,
67 RedisError,
68 ResponseError,
69 WatchError,
70)
71from redis.typing import ChannelT, EncodableT, KeyT
72from redis.utils import (
73 SSL_AVAILABLE,
74 _set_info_logger,
75 deprecated_args,
76 deprecated_function,
77 get_lib_version,
78 safe_str,
79 str_if_bytes,
80 truncate_text,
81)
82
# The ssl types below are only used in annotations (see Redis.__init__).
# Static type checkers import the real classes when ssl is available; at
# runtime TYPE_CHECKING is False, so the else branch binds ``None``
# placeholders and the annotations that mention these names still evaluate.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Type of a pub/sub handler: receives a message dict, returns an awaitable.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Bound to Redis so methods annotated with it (e.g. initialize) return the
# concrete subclass type.
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    # Imported for type checking only; presumably referenced in annotations
    # later in this file.
    from redis.commands.core import Script
97
98
class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback."""

    def __call__(self, response: Any, **kwargs):
        """Transform a raw server ``response``; extra command options arrive as keyword arguments."""
101
102
class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of a coroutine-based response callback."""

    async def __call__(self, response: Any, **kwargs):
        """Asynchronously transform a raw server ``response``."""
105
106
# A response callback may be a plain callable or a coroutine function;
# Redis.parse_response awaits the callback's result when it is awaitable.
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
108
109
class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Every instance
    talks to Redis through a ConnectionPool; when ``single_connection_client``
    is True, one Connection taken from that pool is held and reused for
    every command.
    """

    # Command name -> callback applied to the raw reply in parse_response().
    # Populated in __init__ as a CaseInsensitiveDict, so lookups ignore case.
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more
          at: <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        The returned client owns the pool created here:
        ``auto_close_connection_pool`` defaults to True, so ``aclose()`` also
        disconnects the pool. Passing ``auto_close_connection_pool``
        explicitly is deprecated and emits a DeprecationWarning.
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        # Set after construction: __init__ forces False for a passed-in pool.
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.

        :param connection_pool: the pool the new client will use and own.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        # Overrides the False that __init__ sets for externally supplied pools.
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object(in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used.

        Without 'connection_pool', a new ConnectionPool is built from the
        connection arguments: a Unix socket connection when `unix_socket_path`
        is set, otherwise TCP (wrapped in SSL when `ssl` is True). That pool
        receives a deep copy of `retry` and is disconnected by `aclose()`,
        because `auto_close_connection_pool` then defaults to True. A pool
        passed in is never closed by this client.

        Passing `auto_close_connection_pool` is deprecated and emits a
        DeprecationWarning.

        With `single_connection_client=True`, one connection is taken from the
        pool lazily in `initialize()` and reused for every command, guarded by
        an asyncio.Lock.

        In both cases an AfterPooledConnectionsInstantiationEvent is
        dispatched through `event_dispatcher` (a new EventDispatcher when
        None).
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # Deep copy: the default Retry object is shared by every call
                # of this constructor, so each pool gets its own instance.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        # The dedicated connection of a single-connection client; stays None
        # until initialize() acquires it.
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # RESP3 and RESP2 replies have different shapes, so pick the matching
        # set of callbacks; the protocol may be an int or a string here.
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        """Support ``client = await Redis(...)`` by awaiting ``initialize()``."""
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """
        Prepare the client for use and return it.

        For a single-connection client, acquire the dedicated connection from
        the pool (once, under ``_single_conn_lock``) and dispatch an
        AfterSingleConnectionInstantiationEvent. Note that the event is
        dispatched on every call, and execute_command() calls this method
        for every command. For other clients this is a no-op.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

            self._event_dispatcher.dispatch(
                AfterSingleConnectionInstantiationEvent(
                    self.connection, ClientType.ASYNC, self._single_conn_lock
                )
            )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        """Return the ``retry`` entry of the pool's connection kwargs, or None."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        """
        Record ``retry`` in the pool's connection kwargs and also pass it to
        ``ConnectionPool.set_retry``.
        """
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        `func` may be a plain function or return an awaitable. On WatchError
        the whole attempt is retried (sleeping `watch_delay` seconds first
        when it is positive). Returns `func`'s result when
        `value_from_callable` is True, otherwise the result of execute().
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        """Return a Monitor that uses this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """
        Return a new single-connection client sharing this client's pool.

        Because the pool is passed in, the new client does not close it on
        ``aclose()``.
        """
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        # Shielded so a cancellation of the exiting task cannot leave the
        # counter out of sync or interrupt the close half-way.
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warn and _grl as argument defaults since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        """
        Finalizer for a client that still holds its dedicated connection.

        Emits a ResourceWarning, reports the leak to the running event loop's
        exception handler when a loop is running (a RuntimeError from
        get_running_loop is ignored), then calls the connection's ``_close()``.
        """
        # hasattr guards against __init__ having failed before `connection`
        # was assigned.
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Releases the dedicated single-client connection (if any) back to the
        pool, then optionally disconnects the pool.

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """
        Execute a command and return a parsed response

        ``args[0]`` is the command name and ``args`` are sent as-is. A
        single-connection client uses its dedicated connection under
        ``_single_conn_lock``; otherwise a connection is taken from the pool
        and released afterwards. Send + parse runs through
        ``conn.retry.call_with_retry``, disconnecting the connection between
        failed attempts.
        """
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            # Only pooled (non-dedicated) connections go back to the pool.
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """
        Parses a response from the Redis server

        Recognized options: ``NEVER_DECODE`` reads the reply without decoding;
        ``EMPTY_RESPONSE`` is returned instead of raising when the server
        replies with a ResponseError; ``keys`` is discarded. Remaining options
        are passed to the command's response callback, whose result is
        awaited if it is awaitable. Without a callback the raw reply is
        returned.
        """
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response
765
# Alias of Redis; presumably kept so code written against the older
# ``StrictRedis`` name keeps working.
StrictRedis = Redis
767
768
# One parsed MONITOR line, as returned by Monitor.next_command():
#   time           - server timestamp of the command (seconds, float)
#   db             - database index the command ran against
#   client_address - "lua", "unix", or the host part of a TCP peer
#   client_port    - port / socket suffix ("" for lua)
#   client_type    - "lua", "unix" or "tcp"
#   command        - the command and its arguments joined by spaces
MonitorCommandInfo = TypedDict(
    "MonitorCommandInfo",
    {
        "time": float,
        "db": int,
        "client_address": str,
        "client_port": str,
        "client_type": str,
        "command": str,
    },
)
776
777
class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use it as an async context manager: entering takes a connection from the
    pool, sends MONITOR and checks for an OK reply; exiting disconnects that
    connection and returns it to the pool.
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Take a connection from the pool unless one is already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def _release_connection(self) -> None:
        """Disconnect the held connection (if any) and return it to the pool."""
        conn = self.connection
        if conn is None:
            return
        # Drop our reference first so this Monitor never reads from a
        # connection that has been handed back to the pool.
        self.connection = None
        try:
            await conn.disconnect()
        finally:
            await self.connection_pool.release(conn)

    async def __aenter__(self):
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises (including on
            # cancellation), so return the connection here to avoid leaking it.
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        await self._release_connection()

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()
849
850
851class PubSub:
852 """
853 PubSub provides publish, subscribe and listen support to Redis channels.
854
855 After subscribing to one or more channels, the listen() method will block
856 until a message arrives on one of the subscribed channels. That message
857 will be returned and it's safe to start listening again.
858 """
859
860 PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
861 UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
862 HEALTH_CHECK_MESSAGE = "redis-py-health-check"
863
864 def __init__(
865 self,
866 connection_pool: ConnectionPool,
867 shard_hint: Optional[str] = None,
868 ignore_subscribe_messages: bool = False,
869 encoder=None,
870 push_handler_func: Optional[Callable] = None,
871 event_dispatcher: Optional["EventDispatcher"] = None,
872 ):
873 if event_dispatcher is None:
874 self._event_dispatcher = EventDispatcher()
875 else:
876 self._event_dispatcher = event_dispatcher
877 self.connection_pool = connection_pool
878 self.shard_hint = shard_hint
879 self.ignore_subscribe_messages = ignore_subscribe_messages
880 self.connection = None
881 # we need to know the encoding options for this connection in order
882 # to lookup channel and pattern names for callback handlers.
883 self.encoder = encoder
884 self.push_handler_func = push_handler_func
885 if self.encoder is None:
886 self.encoder = self.connection_pool.get_encoder()
887 if self.encoder.decode_responses:
888 self.health_check_response = [
889 ["pong", self.HEALTH_CHECK_MESSAGE],
890 self.HEALTH_CHECK_MESSAGE,
891 ]
892 else:
893 self.health_check_response = [
894 [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
895 self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
896 ]
897 if self.push_handler_func is None:
898 _set_info_logger()
899 self.channels = {}
900 self.pending_unsubscribe_channels = set()
901 self.patterns = {}
902 self.pending_unsubscribe_patterns = set()
903 self._lock = asyncio.Lock()
904
905 async def __aenter__(self):
906 return self
907
908 async def __aexit__(self, exc_type, exc_value, traceback):
909 await self.aclose()
910
911 def __del__(self):
912 if self.connection:
913 self.connection.deregister_connect_callback(self.on_connect)
914
915 async def aclose(self):
916 # In case a connection property does not yet exist
917 # (due to a crash earlier in the Redis() constructor), return
918 # immediately as there is nothing to clean-up.
919 if not hasattr(self, "connection"):
920 return
921 async with self._lock:
922 if self.connection:
923 await self.connection.disconnect()
924 self.connection.deregister_connect_callback(self.on_connect)
925 await self.connection_pool.release(self.connection)
926 self.connection = None
927 self.channels = {}
928 self.pending_unsubscribe_channels = set()
929 self.patterns = {}
930 self.pending_unsubscribe_patterns = set()
931
932 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
933 async def close(self) -> None:
934 """Alias for aclose(), for backwards compatibility"""
935 await self.aclose()
936
937 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
938 async def reset(self) -> None:
939 """Alias for aclose(), for backwards compatibility"""
940 await self.aclose()
941
942 async def on_connect(self, connection: Connection):
943 """Re-subscribe to any channels and patterns previously subscribed to"""
944 # NOTE: for python3, we can't pass bytestrings as keyword arguments
945 # so we need to decode channel/pattern names back to unicode strings
946 # before passing them to [p]subscribe.
947 self.pending_unsubscribe_channels.clear()
948 self.pending_unsubscribe_patterns.clear()
949 if self.channels:
950 channels = {}
951 for k, v in self.channels.items():
952 channels[self.encoder.decode(k, force=True)] = v
953 await self.subscribe(**channels)
954 if self.patterns:
955 patterns = {}
956 for k, v in self.patterns.items():
957 patterns[self.encoder.decode(k, force=True)] = v
958 await self.psubscribe(**patterns)
959
960 @property
961 def subscribed(self):
962 """Indicates if there are subscriptions to any channels or patterns"""
963 return bool(self.channels or self.patterns)
964
965 async def execute_command(self, *args: EncodableT):
966 """Execute a publish/subscribe command"""
967
968 # NOTE: don't parse the response in this function -- it could pull a
969 # legitimate message off the stack if the connection is already
970 # subscribed to one or more channels
971
972 await self.connect()
973 connection = self.connection
974 kwargs = {"check_health": not self.subscribed}
975 await self._execute(connection, connection.send_command, *args, **kwargs)
976
977 async def connect(self):
978 """
979 Ensure that the PubSub is connected
980 """
981 if self.connection is None:
982 self.connection = await self.connection_pool.get_connection()
983 # register a callback that re-subscribes to any channels we
984 # were listening to when we were disconnected
985 self.connection.register_connect_callback(self.on_connect)
986 else:
987 await self.connection.connect()
988 if self.push_handler_func is not None:
989 self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
990
991 self._event_dispatcher.dispatch(
992 AfterPubSubConnectionInstantiationEvent(
993 self.connection, self.connection_pool, ClientType.ASYNC, self._lock
994 )
995 )
996
997 async def _reconnect(self, conn):
998 """
999 Try to reconnect
1000 """
1001 await conn.disconnect()
1002 await conn.connect()
1003
1004 async def _execute(self, conn, command, *args, **kwargs):
1005 """
1006 Connect manually upon disconnection. If the Redis server is down,
1007 this will fail and raise a ConnectionError as desired.
1008 After reconnection, the ``on_connect`` callback should have been
1009 called by the # connection to resubscribe us to any channels and
1010 patterns we were previously listening to
1011 """
1012 return await conn.retry.call_with_retry(
1013 lambda: command(*args, **kwargs),
1014 lambda _: self._reconnect(conn),
1015 )
1016
1017 async def parse_response(self, block: bool = True, timeout: float = 0):
1018 """Parse the response from a publish/subscribe command"""
1019 conn = self.connection
1020 if conn is None:
1021 raise RuntimeError(
1022 "pubsub connection not set: "
1023 "did you forget to call subscribe() or psubscribe()?"
1024 )
1025
1026 await self.check_health()
1027
1028 if not conn.is_connected:
1029 await conn.connect()
1030
1031 read_timeout = None if block else timeout
1032 response = await self._execute(
1033 conn,
1034 conn.read_response,
1035 timeout=read_timeout,
1036 disconnect_on_error=False,
1037 push_request=True,
1038 )
1039
1040 if conn.health_check_interval and response in self.health_check_response:
1041 # ignore the health check message as user might not expect it
1042 return None
1043 return response
1044
1045 async def check_health(self):
1046 conn = self.connection
1047 if conn is None:
1048 raise RuntimeError(
1049 "pubsub connection not set: "
1050 "did you forget to call subscribe() or psubscribe()?"
1051 )
1052
1053 if (
1054 conn.health_check_interval
1055 and asyncio.get_running_loop().time() > conn.next_health_check
1056 ):
1057 await conn.send_command(
1058 "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
1059 )
1060
1061 def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
1062 """
1063 normalize channel/pattern names to be either bytes or strings
1064 based on whether responses are automatically decoded. this saves us
1065 from coercing the value for each message coming in.
1066 """
1067 encode = self.encoder.encode
1068 decode = self.encoder.decode
1069 return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501
1070
1071 async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
1072 """
1073 Subscribe to channel patterns. Patterns supplied as keyword arguments
1074 expect a pattern name as the key and a callable as the value. A
1075 pattern's callable will be invoked automatically when a message is
1076 received on that pattern rather than producing a message via
1077 ``listen()``.
1078 """
1079 parsed_args = list_or_args((args[0],), args[1:]) if args else args
1080 new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
1081 # Mypy bug: https://github.com/python/mypy/issues/10970
1082 new_patterns.update(kwargs) # type: ignore[arg-type]
1083 ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
1084 # update the patterns dict AFTER we send the command. we don't want to
1085 # subscribe twice to these patterns, once for the command and again
1086 # for the reconnection.
1087 new_patterns = self._normalize_keys(new_patterns)
1088 self.patterns.update(new_patterns)
1089 self.pending_unsubscribe_patterns.difference_update(new_patterns)
1090 return ret_val
1091
1092 def punsubscribe(self, *args: ChannelT) -> Awaitable:
1093 """
1094 Unsubscribe from the supplied patterns. If empty, unsubscribe from
1095 all patterns.
1096 """
1097 patterns: Iterable[ChannelT]
1098 if args:
1099 parsed_args = list_or_args((args[0],), args[1:])
1100 patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
1101 else:
1102 parsed_args = []
1103 patterns = self.patterns
1104 self.pending_unsubscribe_patterns.update(patterns)
1105 return self.execute_command("PUNSUBSCRIBE", *parsed_args)
1106
1107 async def subscribe(self, *args: ChannelT, **kwargs: Callable):
1108 """
1109 Subscribe to channels. Channels supplied as keyword arguments expect
1110 a channel name as the key and a callable as the value. A channel's
1111 callable will be invoked automatically when a message is received on
1112 that channel rather than producing a message via ``listen()`` or
1113 ``get_message()``.
1114 """
1115 parsed_args = list_or_args((args[0],), args[1:]) if args else ()
1116 new_channels = dict.fromkeys(parsed_args)
1117 # Mypy bug: https://github.com/python/mypy/issues/10970
1118 new_channels.update(kwargs) # type: ignore[arg-type]
1119 ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
1120 # update the channels dict AFTER we send the command. we don't want to
1121 # subscribe twice to these channels, once for the command and again
1122 # for the reconnection.
1123 new_channels = self._normalize_keys(new_channels)
1124 self.channels.update(new_channels)
1125 self.pending_unsubscribe_channels.difference_update(new_channels)
1126 return ret_val
1127
1128 def unsubscribe(self, *args) -> Awaitable:
1129 """
1130 Unsubscribe from the supplied channels. If empty, unsubscribe from
1131 all channels
1132 """
1133 if args:
1134 parsed_args = list_or_args(args[0], args[1:])
1135 channels = self._normalize_keys(dict.fromkeys(parsed_args))
1136 else:
1137 parsed_args = []
1138 channels = self.channels
1139 self.pending_unsubscribe_channels.update(channels)
1140 return self.execute_command("UNSUBSCRIBE", *parsed_args)
1141
1142 async def listen(self) -> AsyncIterator:
1143 """Listen for messages on channels this client has been subscribed to"""
1144 while self.subscribed:
1145 response = await self.handle_message(await self.parse_response(block=True))
1146 if response is not None:
1147 yield response
1148
1149 async def get_message(
1150 self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
1151 ):
1152 """
1153 Get the next message if one is available, otherwise None.
1154
1155 If timeout is specified, the system will wait for `timeout` seconds
1156 before returning. Timeout should be specified as a floating point
1157 number or None to wait indefinitely.
1158 """
1159 response = await self.parse_response(block=(timeout is None), timeout=timeout)
1160 if response:
1161 return await self.handle_message(response, ignore_subscribe_messages)
1162 return None
1163
1164 def ping(self, message=None) -> Awaitable[bool]:
1165 """
1166 Ping the Redis server to test connectivity.
1167
1168 Sends a PING command to the Redis server and returns True if the server
1169 responds with "PONG".
1170 """
1171 args = ["PING", message] if message is not None else ["PING"]
1172 return self.execute_command(*args)
1173
1174 async def handle_message(self, response, ignore_subscribe_messages=False):
1175 """
1176 Parses a pub/sub message. If the channel or pattern was subscribed to
1177 with a message handler, the handler is invoked instead of a parsed
1178 message being returned.
1179 """
1180 if response is None:
1181 return None
1182 if isinstance(response, bytes):
1183 response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
1184 message_type = str_if_bytes(response[0])
1185 if message_type == "pmessage":
1186 message = {
1187 "type": message_type,
1188 "pattern": response[1],
1189 "channel": response[2],
1190 "data": response[3],
1191 }
1192 elif message_type == "pong":
1193 message = {
1194 "type": message_type,
1195 "pattern": None,
1196 "channel": None,
1197 "data": response[1],
1198 }
1199 else:
1200 message = {
1201 "type": message_type,
1202 "pattern": None,
1203 "channel": response[1],
1204 "data": response[2],
1205 }
1206
1207 # if this is an unsubscribe message, remove it from memory
1208 if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
1209 if message_type == "punsubscribe":
1210 pattern = response[1]
1211 if pattern in self.pending_unsubscribe_patterns:
1212 self.pending_unsubscribe_patterns.remove(pattern)
1213 self.patterns.pop(pattern, None)
1214 else:
1215 channel = response[1]
1216 if channel in self.pending_unsubscribe_channels:
1217 self.pending_unsubscribe_channels.remove(channel)
1218 self.channels.pop(channel, None)
1219
1220 if message_type in self.PUBLISH_MESSAGE_TYPES:
1221 # if there's a message handler, invoke it
1222 if message_type == "pmessage":
1223 handler = self.patterns.get(message["pattern"], None)
1224 else:
1225 handler = self.channels.get(message["channel"], None)
1226 if handler:
1227 if inspect.iscoroutinefunction(handler):
1228 await handler(message)
1229 else:
1230 handler(message)
1231 return None
1232 elif message_type != "pong":
1233 # this is a subscribe/unsubscribe message. ignore if we don't
1234 # want them
1235 if ignore_subscribe_messages or self.ignore_subscribe_messages:
1236 return None
1237
1238 return message
1239
1240 async def run(
1241 self,
1242 *,
1243 exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
1244 poll_timeout: float = 1.0,
1245 pubsub=None,
1246 ) -> None:
1247 """Process pub/sub messages using registered callbacks.
1248
1249 This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
1250 redis-py, but it is a coroutine. To launch it as a separate task, use
1251 ``asyncio.create_task``:
1252
1253 >>> task = asyncio.create_task(pubsub.run())
1254
1255 To shut it down, use asyncio cancellation:
1256
1257 >>> task.cancel()
1258 >>> await task
1259 """
1260 for channel, handler in self.channels.items():
1261 if handler is None:
1262 raise PubSubError(f"Channel: '{channel}' has no handler registered")
1263 for pattern, handler in self.patterns.items():
1264 if handler is None:
1265 raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
1266
1267 await self.connect()
1268 while True:
1269 try:
1270 if pubsub is None:
1271 await self.get_message(
1272 ignore_subscribe_messages=True, timeout=poll_timeout
1273 )
1274 else:
1275 await pubsub.get_message(
1276 ignore_subscribe_messages=True, timeout=poll_timeout
1277 )
1278 except asyncio.CancelledError:
1279 raise
1280 except BaseException as e:
1281 if exception_handler is None:
1282 raise
1283 res = exception_handler(e, self)
1284 if inspect.isawaitable(res):
1285 await res
1286 # Ensure that other tasks on the event loop get a chance to run
1287 # if we didn't have to block for I/O anywhere.
1288 await asyncio.sleep(0)
1289
1290
class PubsubWorkerExceptionHandler(Protocol):
    """Synchronous ``exception_handler`` accepted by ``PubSub.run``.

    Called with the exception raised while polling and the PubSub instance.
    """

    def __call__(self, e: BaseException, pubsub: PubSub): ...
1293
1294
class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Coroutine ``exception_handler`` accepted by ``PubSub.run``.

    Same arguments as the synchronous variant; ``run`` awaits the result.
    """

    async def __call__(self, e: BaseException, pubsub: PubSub): ...
1297
1298
# Either handler flavor may be given to ``PubSub.run``; it awaits the
# handler's result only when that result is awaitable.
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]
1302
1303
# One queued pipeline command: (positional args, options). args[0] is the
# command name; the options mapping is forwarded to response parsing.
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# A pipeline's queued commands, in the order they were issued.
CommandStackT = List[CommandT]
1306
1307
1308class Pipeline(Redis): # lgtm [py/init-calls-subclass]
1309 """
1310 Pipelines provide a way to transmit multiple commands to the Redis server
1311 in one transmission. This is convenient for batch processing, such as
1312 saving all the values in a list to Redis.
1313
1314 All commands executed within a pipeline(when running in transactional mode,
1315 which is the default behavior) are wrapped with MULTI and EXEC
1316 calls. This guarantees all commands executed in the pipeline will be
1317 executed atomically.
1318
1319 Any command raising an exception does *not* halt the execution of
1320 subsequent commands in the pipeline. Instead, the exception is caught
1321 and its instance is placed into the response list returned by execute().
1322 Code iterating over the response list should be able to deal with an
1323 instance of an exception as a potential value. In general, these will be
1324 ResponseError exceptions, such as those raised when issuing a command
1325 on a key of a different datatype.
1326 """
1327
1328 UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
1329
1330 def __init__(
1331 self,
1332 connection_pool: ConnectionPool,
1333 response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
1334 transaction: bool,
1335 shard_hint: Optional[str],
1336 ):
1337 self.connection_pool = connection_pool
1338 self.connection = None
1339 self.response_callbacks = response_callbacks
1340 self.is_transaction = transaction
1341 self.shard_hint = shard_hint
1342 self.watching = False
1343 self.command_stack: CommandStackT = []
1344 self.scripts: Set[Script] = set()
1345 self.explicit_transaction = False
1346
1347 async def __aenter__(self: _RedisT) -> _RedisT:
1348 return self
1349
1350 async def __aexit__(self, exc_type, exc_value, traceback):
1351 await self.reset()
1352
1353 def __await__(self):
1354 return self._async_self().__await__()
1355
    # NOTE(review): presumably read by the parent Redis class's unclosed-client
    # finalizer; its use is not visible in this class.
    _DEL_MESSAGE = "Unclosed Pipeline client"
1357
1358 def __len__(self):
1359 return len(self.command_stack)
1360
1361 def __bool__(self):
1362 """Pipeline instances should always evaluate to True"""
1363 return True
1364
1365 async def _async_self(self):
1366 return self
1367
1368 async def reset(self):
1369 self.command_stack = []
1370 self.scripts = set()
1371 # make sure to reset the connection state in the event that we were
1372 # watching something
1373 if self.watching and self.connection:
1374 try:
1375 # call this manually since our unwatch or
1376 # immediate_execute_command methods can call reset()
1377 await self.connection.send_command("UNWATCH")
1378 await self.connection.read_response()
1379 except ConnectionError:
1380 # disconnect will also remove any previous WATCHes
1381 if self.connection:
1382 await self.connection.disconnect()
1383 # clean up the other instance attributes
1384 self.watching = False
1385 self.explicit_transaction = False
1386 # we can safely return the connection to the pool here since we're
1387 # sure we're no longer WATCHing anything
1388 if self.connection:
1389 await self.connection_pool.release(self.connection)
1390 self.connection = None
1391
1392 async def aclose(self) -> None:
1393 """Alias for reset(), a standard method name for cleanup"""
1394 await self.reset()
1395
1396 def multi(self):
1397 """
1398 Start a transactional block of the pipeline after WATCH commands
1399 are issued. End the transactional block with `execute`.
1400 """
1401 if self.explicit_transaction:
1402 raise RedisError("Cannot issue nested calls to MULTI")
1403 if self.command_stack:
1404 raise RedisError(
1405 "Commands without an initial WATCH have already been issued"
1406 )
1407 self.explicit_transaction = True
1408
1409 def execute_command(
1410 self, *args, **kwargs
1411 ) -> Union["Pipeline", Awaitable["Pipeline"]]:
1412 if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
1413 return self.immediate_execute_command(*args, **kwargs)
1414 return self.pipeline_execute_command(*args, **kwargs)
1415
1416 async def _disconnect_reset_raise_on_watching(
1417 self,
1418 conn: Connection,
1419 error: Exception,
1420 ):
1421 """
1422 Close the connection reset watching state and
1423 raise an exception if we were watching.
1424
1425 The supported exceptions are already checked in the
1426 retry object so we don't need to do it here.
1427
1428 After we disconnect the connection, it will try to reconnect and
1429 do a health check as part of the send_command logic(on connection level).
1430 """
1431 await conn.disconnect()
1432 # if we were already watching a variable, the watch is no longer
1433 # valid since this connection has died. raise a WatchError, which
1434 # indicates the user should retry this transaction.
1435 if self.watching:
1436 await self.reset()
1437 raise WatchError(
1438 f"A {type(error).__name__} occurred while watching one or more keys"
1439 )
1440
1441 async def immediate_execute_command(self, *args, **options):
1442 """
1443 Execute a command immediately, but don't auto-retry on the supported
1444 errors for retry if we're already WATCHing a variable.
1445 Used when issuing WATCH or subsequent commands retrieving their values but before
1446 MULTI is called.
1447 """
1448 command_name = args[0]
1449 conn = self.connection
1450 # if this is the first call, we need a connection
1451 if not conn:
1452 conn = await self.connection_pool.get_connection()
1453 self.connection = conn
1454
1455 return await conn.retry.call_with_retry(
1456 lambda: self._send_command_parse_response(
1457 conn, command_name, *args, **options
1458 ),
1459 lambda error: self._disconnect_reset_raise_on_watching(conn, error),
1460 )
1461
1462 def pipeline_execute_command(self, *args, **options):
1463 """
1464 Stage a command to be executed when execute() is next called
1465
1466 Returns the current Pipeline object back so commands can be
1467 chained together, such as:
1468
1469 pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
1470
1471 At some other point, you can then run: pipe.execute(),
1472 which will execute all commands queued in the pipe.
1473 """
1474 self.command_stack.append((args, options))
1475 return self
1476
1477 async def _execute_transaction( # noqa: C901
1478 self, connection: Connection, commands: CommandStackT, raise_on_error
1479 ):
1480 pre: CommandT = (("MULTI",), {})
1481 post: CommandT = (("EXEC",), {})
1482 cmds = (pre, *commands, post)
1483 all_cmds = connection.pack_commands(
1484 args for args, options in cmds if EMPTY_RESPONSE not in options
1485 )
1486 await connection.send_packed_command(all_cmds)
1487 errors = []
1488
1489 # parse off the response for MULTI
1490 # NOTE: we need to handle ResponseErrors here and continue
1491 # so that we read all the additional command messages from
1492 # the socket
1493 try:
1494 await self.parse_response(connection, "_")
1495 except ResponseError as err:
1496 errors.append((0, err))
1497
1498 # and all the other commands
1499 for i, command in enumerate(commands):
1500 if EMPTY_RESPONSE in command[1]:
1501 errors.append((i, command[1][EMPTY_RESPONSE]))
1502 else:
1503 try:
1504 await self.parse_response(connection, "_")
1505 except ResponseError as err:
1506 self.annotate_exception(err, i + 1, command[0])
1507 errors.append((i, err))
1508
1509 # parse the EXEC.
1510 try:
1511 response = await self.parse_response(connection, "_")
1512 except ExecAbortError as err:
1513 if errors:
1514 raise errors[0][1] from err
1515 raise
1516
1517 # EXEC clears any watched keys
1518 self.watching = False
1519
1520 if response is None:
1521 raise WatchError("Watched variable changed.") from None
1522
1523 # put any parse errors into the response
1524 for i, e in errors:
1525 response.insert(i, e)
1526
1527 if len(response) != len(commands):
1528 if self.connection:
1529 await self.connection.disconnect()
1530 raise ResponseError(
1531 "Wrong number of response items from pipeline execution"
1532 ) from None
1533
1534 # find any errors in the response and raise if necessary
1535 if raise_on_error:
1536 self.raise_first_error(commands, response)
1537
1538 # We have to run response callbacks manually
1539 data = []
1540 for r, cmd in zip(response, commands):
1541 if not isinstance(r, Exception):
1542 args, options = cmd
1543 command_name = args[0]
1544
1545 # Remove keys entry, it needs only for cache.
1546 options.pop("keys", None)
1547
1548 if command_name in self.response_callbacks:
1549 r = self.response_callbacks[command_name](r, **options)
1550 if inspect.isawaitable(r):
1551 r = await r
1552 data.append(r)
1553 return data
1554
1555 async def _execute_pipeline(
1556 self, connection: Connection, commands: CommandStackT, raise_on_error: bool
1557 ):
1558 # build up all commands into a single request to increase network perf
1559 all_cmds = connection.pack_commands([args for args, _ in commands])
1560 await connection.send_packed_command(all_cmds)
1561
1562 response = []
1563 for args, options in commands:
1564 try:
1565 response.append(
1566 await self.parse_response(connection, args[0], **options)
1567 )
1568 except ResponseError as e:
1569 response.append(e)
1570
1571 if raise_on_error:
1572 self.raise_first_error(commands, response)
1573 return response
1574
1575 def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
1576 for i, r in enumerate(response):
1577 if isinstance(r, ResponseError):
1578 self.annotate_exception(r, i + 1, commands[i][0])
1579 raise r
1580
1581 def annotate_exception(
1582 self, exception: Exception, number: int, command: Iterable[object]
1583 ) -> None:
1584 cmd = " ".join(map(safe_str, command))
1585 msg = (
1586 f"Command # {number} ({truncate_text(cmd)}) "
1587 "of pipeline caused error: {exception.args}"
1588 )
1589 exception.args = (msg,) + exception.args[1:]
1590
1591 async def parse_response(
1592 self, connection: Connection, command_name: Union[str, bytes], **options
1593 ):
1594 result = await super().parse_response(connection, command_name, **options)
1595 if command_name in self.UNWATCH_COMMANDS:
1596 self.watching = False
1597 elif command_name == "WATCH":
1598 self.watching = True
1599 return result
1600
1601 async def load_scripts(self):
1602 # make sure all scripts that are about to be run on this pipeline exist
1603 scripts = list(self.scripts)
1604 immediate = self.immediate_execute_command
1605 shas = [s.sha for s in scripts]
1606 # we can't use the normal script_* methods because they would just
1607 # get buffered in the pipeline.
1608 exists = await immediate("SCRIPT EXISTS", *shas)
1609 if not all(exists):
1610 for s, exist in zip(scripts, exists):
1611 if not exist:
1612 s.sha = await immediate("SCRIPT LOAD", s.script)
1613
1614 async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
1615 """
1616 Close the connection, raise an exception if we were watching.
1617
1618 The supported exceptions are already checked in the
1619 retry object so we don't need to do it here.
1620
1621 After we disconnect the connection, it will try to reconnect and
1622 do a health check as part of the send_command logic(on connection level).
1623 """
1624 await conn.disconnect()
1625 # if we were watching a variable, the watch is no longer valid
1626 # since this connection has died. raise a WatchError, which
1627 # indicates the user should retry this transaction.
1628 if self.watching:
1629 raise WatchError(
1630 f"A {type(error).__name__} occurred while watching one or more keys"
1631 )
1632
1633 async def execute(self, raise_on_error: bool = True) -> List[Any]:
1634 """Execute all the commands in the current pipeline"""
1635 stack = self.command_stack
1636 if not stack and not self.watching:
1637 return []
1638 if self.scripts:
1639 await self.load_scripts()
1640 if self.is_transaction or self.explicit_transaction:
1641 execute = self._execute_transaction
1642 else:
1643 execute = self._execute_pipeline
1644
1645 conn = self.connection
1646 if not conn:
1647 conn = await self.connection_pool.get_connection()
1648 # assign to self.connection so reset() releases the connection
1649 # back to the pool after we're done
1650 self.connection = conn
1651 conn = cast(Connection, conn)
1652
1653 try:
1654 return await conn.retry.call_with_retry(
1655 lambda: execute(conn, stack, raise_on_error),
1656 lambda error: self._disconnect_raise_on_watching(conn, error),
1657 )
1658 finally:
1659 await self.reset()
1660
1661 async def discard(self):
1662 """Flushes all previously queued commands
1663 See: https://redis.io/commands/DISCARD
1664 """
1665 await self.execute_command("DISCARD")
1666
1667 async def watch(self, *names: KeyT):
1668 """Watches the values at keys ``names``"""
1669 if self.explicit_transaction:
1670 raise RedisError("Cannot issue a WATCH after a MULTI")
1671 return await self.execute_command("WATCH", *names)
1672
1673 async def unwatch(self):
1674 """Unwatches all previously specified keys"""
1675 return self.watching and await self.execute_command("UNWATCH") or True