import asyncio
import copy
import inspect
import re
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Protocol,
    Set,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

from redis._parsers.helpers import (
    _RedisCallbacks,
    _RedisCallbacksRESP2,
    _RedisCallbacksRESP3,
    bool_ok,
)
from redis.asyncio.connection import (
    Connection,
    ConnectionPool,
    SSLConnection,
    UnixDomainSocketConnection,
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
    EMPTY_RESPONSE,
    NEVER_DECODE,
    AbstractRedis,
    CaseInsensitiveDict,
)
from redis.commands import (
    AsyncCoreCommands,
    AsyncRedisModuleCommands,
    AsyncSentinelCommands,
    list_or_args,
)
from redis.credentials import CredentialProvider
from redis.event import (
    AfterPooledConnectionsInstantiationEvent,
    AfterPubSubConnectionInstantiationEvent,
    AfterSingleConnectionInstantiationEvent,
    ClientType,
    EventDispatcher,
)
from redis.exceptions import (
    ConnectionError,
    ExecAbortError,
    PubSubError,
    RedisError,
    ResponseError,
    WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
    SSL_AVAILABLE,
    _set_info_logger,
    deprecated_args,
    deprecated_function,
    get_lib_version,
    safe_str,
    str_if_bytes,
    truncate_text,
)

if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    from redis.commands.core import Script

class ResponseCallbackProtocol(Protocol):
    def __call__(self, response: Any, **kwargs): ...


class AsyncResponseCallbackProtocol(Protocol):
    async def __call__(self, response: Any, **kwargs): ...


ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]


class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this class, implementing how commands are sent
    to and received from the Redis server. Depending on configuration, an
    instance will use either a ConnectionPool or a Connection object to
    talk to Redis.
    """

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates an SSL-wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - `unix://` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

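        A minimal usage sketch (assumes a Redis server listening on
        localhost:6379; key name is illustrative)::

            import asyncio

            from redis.asyncio import Redis

            async def main():
                client = Redis.from_url("redis://localhost:6379/0")
                print(await client.ping())
                await client.aclose()

            asyncio.run(main())
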
        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
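
        A short sketch (the URL is illustrative)::

            pool = ConnectionPool.from_url("redis://localhost:6379/0")
            client = Redis.from_pool(pool)
            # the pool is now owned by the client and closed by client.aclose()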
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set `retry_on_error` to a list of the error(s) to retry on, and
           optionally set `retry` to a valid `Retry` object (in case the
           default one is not appropriate) - with this approach, retries will
           be triggered on the default errors specified in the Retry object,
           enriched with the errors specified in `retry_on_error`.

        2. Define a `Retry` object with a configured `supported_errors` and
           set it as the `retry` parameter - with this approach you completely
           redefine the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include `TimeoutError`
        either in the Retry object or in the `retry_on_error` list.

        When `connection_pool` is provided, the retry configuration of the
        provided pool will be used.
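
        A configuration sketch of option 1 (the error class and backoff
        values are illustrative)::

            from redis.asyncio.retry import Retry
            from redis.backoff import ExponentialWithJitterBackoff
            from redis.exceptions import TimeoutError

            client = Redis(
                host="localhost",
                retry=Retry(
                    backoff=ExponentialWithJitterBackoff(base=1, cap=10),
                    retries=5,
                ),
                retry_on_error=[TimeoutError],
            )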
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide it to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False
            self._event_dispatcher.dispatch(
                AfterPooledConnectionsInstantiationEvent(
                    [connection_pool], ClientType.ASYNC, credential_provider
                )
            )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()
    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()

                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's keyword arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces, to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates commands named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
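
        A usage sketch (key names are illustrative)::

            async with client.pipeline(transaction=True) as pipe:
                pipe.set("foo", "bar").incr("counter")
                results = await pipe.execute()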
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.
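
        A sketch of the WATCH/MULTI pattern this method wraps (the key name
        is illustrative)::

            async def double_it(pipe):
                value = int(await pipe.get("counter") or 0)
                pipe.multi()
                pipe.set("counter", value * 2)

            await client.transaction(double_it, "counter")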
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or fail immediately, causing ``acquire``
        to return False with the lock not acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

        time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
              thread-1 sets the token to "abc"
        time: 1, thread-2 blocks trying to acquire `my-lock` using the
              Lock instance.
        time: 5, thread-1 has not yet completed. redis expires the lock
              key.
        time: 5, thread-2 acquired `my-lock` now that it's available.
              thread-2 sets the token to "xyz"
        time: 6, thread-1 finishes its work and calls release(). if the
              token is *not* stored in thread local storage, then
              thread-1 would see the token value as "xyz" and would be
              able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
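
        # Usage sketch (illustrative; the key name is arbitrary):
        #
        #     async with client.lock("my-lock", timeout=5, blocking_timeout=3):
        #         ...  # critical section; the lock is released on exit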
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except Exception:
            # If initialization fails, decrement the counter to keep it in sync
            await self._decrement_usage()
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warn and _grl as argument defaults since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes the Redis client connection.

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (on the
        connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        conn = self.connection or await pool.get_connection()

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove the keys entry; it is only needed for caching.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            return await retval if inspect.isawaitable(retval) else retval
        return response


StrictRedis = Redis


class MonitorCommandInfo(TypedDict):
    time: float
    db: int
    client_address: str
    client_port: str
    client_type: str
    command: str


class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    The next_command() method returns one command from monitor;
    the listen() method yields commands from monitor.
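
    A usage sketch (illustrative)::

        async with client.monitor() as m:
            async for info in m.listen():
                print(info["command"])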
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def __aenter__(self):
        await self.connect()
        await self.connection.send_command("MONITOR")
        # check that monitor returns 'OK', but don't return it to user
        response = await self.connection.read_response()
        if not bool_ok(response):
            raise RedisError(f"MONITOR failed: {response}")
        return self

    async def __aexit__(self, *args):
        await self.connection.disconnect()
        await self.connection_pool.release(self.connection)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()


class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.
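
    A usage sketch (the channel name is illustrative)::

        pubsub = client.pubsub()
        await pubsub.subscribe("my-channel")
        async for message in pubsub.listen():
            print(message)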
    """

    PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
        event_dispatcher: Optional["EventDispatcher"] = None,
    ):
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to look up channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    def __del__(self):
        if self.connection:
            self.connection.deregister_connect_callback(self.on_connect)

    async def aclose(self):
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

    async def on_connect(self, connection: Connection):
        """Re-subscribe to any channels and patterns previously subscribed to"""
        # NOTE: for python3, we can't pass bytestrings as keyword arguments
        # so we need to decode channel/pattern names back to unicode strings
        # before passing them to [p]subscribe.
        self.pending_unsubscribe_channels.clear()
        self.pending_unsubscribe_patterns.clear()
        if self.channels:
            channels = {}
            for k, v in self.channels.items():
                channels[self.encoder.decode(k, force=True)] = v
            await self.subscribe(**channels)
        if self.patterns:
            patterns = {}
            for k, v in self.patterns.items():
                patterns[self.encoder.decode(k, force=True)] = v
            await self.psubscribe(**patterns)

    @property
    def subscribed(self):
        """Indicates if there are subscriptions to any channels or patterns"""
        return bool(self.channels or self.patterns)

    async def execute_command(self, *args: EncodableT):
        """Execute a publish/subscribe command"""

        # NOTE: don't parse the response in this function -- it could pull a
        # legitimate message off the stack if the connection is already
        # subscribed to one or more channels

        await self.connect()
        connection = self.connection
        kwargs = {"check_health": not self.subscribed}
        await self._execute(connection, connection.send_command, *args, **kwargs)

    async def connect(self):
        """
        Ensure that the PubSub is connected
        """
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()
            # register a callback that re-subscribes to any channels we
            # were listening to when we were disconnected
            self.connection.register_connect_callback(self.on_connect)
        else:
            await self.connection.connect()
        if self.push_handler_func is not None:
            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

        self._event_dispatcher.dispatch(
            AfterPubSubConnectionInstantiationEvent(
                self.connection, self.connection_pool, ClientType.ASYNC, self._lock
            )
        )

    async def _reconnect(self, conn):
        """
        Try to reconnect
        """
        await conn.disconnect()
        await conn.connect()

    async def _execute(self, conn, command, *args, **kwargs):
        """
        Connect manually upon disconnection. If the Redis server is down,
        this will fail and raise a ConnectionError as desired.
        After reconnection, the ``on_connect`` callback should have been
        called by the connection to resubscribe us to any channels and
        patterns we were previously listening to.
        """
        return await conn.retry.call_with_retry(
            lambda: command(*args, **kwargs),
            lambda _: self._reconnect(conn),
        )

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """Parse the response from a publish/subscribe command"""
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message, as the user might not expect it
            return None
        return response

    async def check_health(self):
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        if (
            conn.health_check_interval
            and asyncio.get_running_loop().time() > conn.next_health_check
        ):
            await conn.send_command(
                "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
            )

    def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
        """
        normalize channel/pattern names to be either bytes or strings
        based on whether responses are automatically decoded. this saves us
        from coercing the value for each message coming in.
        """
        encode = self.encoder.encode
        decode = self.encoder.decode
        return {decode(encode(k)): v for k, v in data.items()}  # type: ignore[return-value] # noqa: E501

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else args
        new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_patterns.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
        # update the patterns dict AFTER we send the command. we don't want to
        # subscribe twice to these patterns, once for the command and again
        # for the reconnection.
        new_patterns = self._normalize_keys(new_patterns)
        self.patterns.update(new_patterns)
        self.pending_unsubscribe_patterns.difference_update(new_patterns)
        return ret_val

    def punsubscribe(self, *args: ChannelT) -> Awaitable:
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        patterns: Iterable[ChannelT]
        if args:
            parsed_args = list_or_args((args[0],), args[1:])
            patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
        else:
            parsed_args = []
            patterns = self.patterns
        self.pending_unsubscribe_patterns.update(patterns)
        return self.execute_command("PUNSUBSCRIBE", *parsed_args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
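
        A sketch of both forms (``reader`` is a hypothetical handler)::

            async def reader(message):
                print(message["channel"], message["data"])

            await pubsub.subscribe("news", sports=reader)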
        """
        parsed_args = list_or_args((args[0],), args[1:]) if args else ()
        new_channels = dict.fromkeys(parsed_args)
        # Mypy bug: https://github.com/python/mypy/issues/10970
        new_channels.update(kwargs)  # type: ignore[arg-type]
        ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
        # update the channels dict AFTER we send the command. we don't want to
        # subscribe twice to these channels, once for the command and again
        # for the reconnection.
        new_channels = self._normalize_keys(new_channels)
        self.channels.update(new_channels)
        self.pending_unsubscribe_channels.difference_update(new_channels)
        return ret_val

    def unsubscribe(self, *args) -> Awaitable:
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        if args:
            parsed_args = list_or_args(args[0], args[1:])
            channels = self._normalize_keys(dict.fromkeys(parsed_args))
        else:
            parsed_args = []
            channels = self.channels
        self.pending_unsubscribe_channels.update(channels)
        return self.execute_command("UNSUBSCRIBE", *parsed_args)

    async def listen(self) -> AsyncIterator:
        """Listen for messages on channels this client has been subscribed to"""
        while self.subscribed:
            response = await self.handle_message(await self.parse_response(block=True))
            if response is not None:
                yield response

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
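
        A polling sketch (the timeout value is illustrative)::

            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is not None:
                print(message["data"])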
        """
        response = await self.parse_response(block=(timeout is None), timeout=timeout)
        if response:
            return await self.handle_message(response, ignore_subscribe_messages)
        return None

    def ping(self, message=None) -> Awaitable:
        """
        Ping the Redis server
        """
        args = ["PING", message] if message is not None else ["PING"]
        return self.execute_command(*args)

    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.
        """
        if response is None:
            return None
        if isinstance(response, bytes):
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                await self.get_message(
                    ignore_subscribe_messages=True, timeout=poll_timeout
                )
            except asyncio.CancelledError:
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)


class PubsubWorkerExceptionHandler(Protocol):
    def __call__(self, e: BaseException, pubsub: PubSub): ...


class AsyncPubsubWorkerExceptionHandler(Protocol):
    async def __call__(self, e: BaseException, pubsub: PubSub): ...


PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]


CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
CommandStackT = List[CommandT]


class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline (when running in transactional
    mode, which is the default behavior) are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
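
    A sketch of the error-in-response behavior (key names are illustrative)::

        async with client.pipeline() as pipe:
            pipe.set("foo", "bar").lpush("foo", "x")  # LPUSH on a string errors
            results = await pipe.execute(raise_on_error=False)
            # results[1] is a ResponseError instance, not a raised exception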
    """

    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

    def __init__(
        self,
        connection_pool: ConnectionPool,
        response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
        transaction: bool,
        shard_hint: Optional[str],
    ):
        self.connection_pool = connection_pool
        self.connection = None
        self.response_callbacks = response_callbacks
        self.is_transaction = transaction
        self.shard_hint = shard_hint
        self.watching = False
        self.command_stack: CommandStackT = []
        self.scripts: Set["Script"] = set()
        self.explicit_transaction = False

    async def __aenter__(self: _RedisT) -> _RedisT:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.reset()

    def __await__(self):
        return self._async_self().__await__()

    _DEL_MESSAGE = "Unclosed Pipeline client"

    def __len__(self):
        return len(self.command_stack)

    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        return True

    async def _async_self(self):
        return self

    async def reset(self):
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None

    async def aclose(self) -> None:
        """Alias for reset(), a standard method name for cleanup"""
        await self.reset()

    def multi(self):
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    def execute_command(
        self, *args, **kwargs
    ) -> Union["Pipeline", Awaitable["Pipeline"]]:
        if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
            return self.immediate_execute_command(*args, **kwargs)
        return self.pipeline_execute_command(*args, **kwargs)

    async def _disconnect_reset_raise_on_watching(
        self,
        conn: Connection,
        error: Exception,
    ):
        """
        Close the connection, reset the watching state, and
        raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (on the
        connection level).
        """
        await conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.reset()
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on the supported
        errors if we're already WATCHing a variable.
        Used when issuing WATCH, or when issuing subsequent commands that
        retrieve their values, before MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            self.connection = conn

        return await conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise_on_watching(conn, error),
        )

    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self

    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]

                # Remove the keys entry; it is only needed for caching.
                options.pop("keys", None)

                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data

    async def _execute_pipeline(
        self, connection: Connection, commands: CommandStackT, raise_on_error: bool
    ):
        # build up all commands into a single request to increase network perf
        all_cmds = connection.pack_commands([args for args, _ in commands])
        await connection.send_packed_command(all_cmds)

        response = []
        for args, options in commands:
            try:
                response.append(
                    await self.parse_response(connection, args[0], **options)
                )
            except ResponseError as e:
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)
        return response

    def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
        for i, r in enumerate(response):
            if isinstance(r, ResponseError):
                self.annotate_exception(r, i + 1, commands[i][0])
                raise r

    def annotate_exception(
        self, exception: Exception, number: int, command: Iterable[object]
    ) -> None:
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({truncate_text(cmd)}) "
            f"of pipeline caused error: {exception.args}"
        )
        exception.args = (msg,) + exception.args[1:]

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        result = await super().parse_response(connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

    async def load_scripts(self):
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = await immediate("SCRIPT EXISTS", *shas)
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = await immediate("SCRIPT LOAD", s.script)

    async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
        """
        Close the connection and raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object, so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic (on the
        connection level).
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
            )

    async def execute(self, raise_on_error: bool = True) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        stack = self.command_stack
        if not stack and not self.watching:
            return []
        if self.scripts:
            await self.load_scripts()
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection()
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        try:
            return await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_on_watching(conn, error),
            )
        finally:
            await self.reset()

    async def discard(self):
        """Flushes all previously queued commands
        See: https://redis.io/commands/DISCARD
        """
        await self.execute_command("DISCARD")

    async def watch(self, *names: KeyT):
        """Watches the values at keys ``names``"""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return await self.execute_command("WATCH", *names)

    async def unwatch(self):
        """Unwatches all previously specified keys"""
        return self.watching and await self.execute_command("UNWATCH") or True