Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/client.py: 19%
688 statements
coverage.py v7.4.4, created at 2024-04-23 06:16 +0000
1import copy
2import re
3import threading
4import time
5import warnings
6from itertools import chain
7from typing import Any, Callable, Dict, List, Optional, Type, Union
9from redis._cache import (
10 DEFAULT_BLACKLIST,
11 DEFAULT_EVICTION_POLICY,
12 DEFAULT_WHITELIST,
13 AbstractCache,
14)
15from redis._parsers.encoders import Encoder
16from redis._parsers.helpers import (
17 _RedisCallbacks,
18 _RedisCallbacksRESP2,
19 _RedisCallbacksRESP3,
20 bool_ok,
21)
22from redis.commands import (
23 CoreCommands,
24 RedisModuleCommands,
25 SentinelCommands,
26 list_or_args,
27)
28from redis.connection import (
29 AbstractConnection,
30 ConnectionPool,
31 SSLConnection,
32 UnixDomainSocketConnection,
33)
34from redis.credentials import CredentialProvider
35from redis.exceptions import (
36 ConnectionError,
37 ExecAbortError,
38 PubSubError,
39 RedisError,
40 ResponseError,
41 TimeoutError,
42 WatchError,
43)
44from redis.lock import Lock
45from redis.retry import Retry
46from redis.utils import (
47 HIREDIS_AVAILABLE,
48 _set_info_logger,
49 get_lib_version,
50 safe_str,
51 str_if_bytes,
52)
# An empty bytes value.
SYM_EMPTY = b""

# Option keys understood by Redis.parse_response.
#
# EMPTY_RESPONSE: when present, its value is returned instead of
# re-raising a ResponseError from the server.
EMPTY_RESPONSE = "EMPTY_RESPONSE"
# NEVER_DECODE: when present, the reply is read with decoding disabled.
# Some replies (e.g. DUMP) are binary and are never meant to be decoded.
NEVER_DECODE = "NEVER_DECODE"
class CaseInsensitiveDict(dict):
    """Dictionary whose keys are compared case-insensitively.

    Every key is normalised with ``str.upper()`` on write and on lookup, so
    only string keys are supported. All key-taking operations, including
    ``pop`` and ``setdefault``, apply the same normalisation.
    """

    def __init__(self, data: Optional[Dict[str, str]] = None) -> None:
        super().__init__()
        if data is not None:
            self.update(data)

    def __contains__(self, k):
        return super().__contains__(k.upper())

    def __delitem__(self, k):
        super().__delitem__(k.upper())

    def __getitem__(self, k):
        return super().__getitem__(k.upper())

    def get(self, k, default=None):
        return super().get(k.upper(), default)

    def __setitem__(self, k, v):
        super().__setitem__(k.upper(), v)

    def pop(self, k, *default):
        """Remove ``k`` (case-insensitively) and return its value.

        Like ``dict.pop``: an optional default is returned for a missing key,
        otherwise ``KeyError`` is raised.
        """
        return super().pop(k.upper(), *default)

    def setdefault(self, k, default=None):
        """Return the value for ``k``, inserting ``default`` if it is missing."""
        return super().setdefault(k.upper(), default)

    def update(self, data=(), **kwargs):
        """Merge ``data`` and ``kwargs`` into the dict, normalising every key.

        ``data`` may be a mapping (anything with ``.items()``) or an iterable
        of ``(key, value)`` pairs, mirroring ``dict.update``.
        """
        items = data.items() if hasattr(data, "items") else data
        for key, value in items:
            self[key] = value
        for key, value in kwargs.items():
            self[key] = value
class AbstractRedis:
    """Empty base class; it defines no attributes or behaviour of its own."""
class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.

    It is not safe to pass PubSub or Pipeline objects between threads.
    """

    @classmethod
    def from_url(cls, url: str, **kwargs) -> "Redis":
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        The pool created here is owned by the returned client and is
        disconnected when the client is closed.
        """
        single_connection_client = kwargs.pop("single_connection_client", False)
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        # The pool was created here, so the client owns (and closes) it.
        client.auto_close_connection_pool = True
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    def __init__(
        self,
        host="localhost",
        port=6379,
        db=0,
        password=None,
        socket_timeout=None,
        socket_connect_timeout=None,
        socket_keepalive=None,
        socket_keepalive_options=None,
        connection_pool=None,
        unix_socket_path=None,
        encoding="utf-8",
        encoding_errors="strict",
        charset=None,
        errors=None,
        decode_responses=False,
        retry_on_timeout=False,
        retry_on_error=None,
        ssl=False,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_path=None,
        ssl_ca_data=None,
        ssl_check_hostname=False,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        max_connections=None,
        single_connection_client=False,
        health_check_interval=0,
        client_name=None,
        lib_name="redis-py",
        lib_version=get_lib_version(),
        username=None,
        retry=None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        cache_enabled: bool = False,
        client_cache: Optional[AbstractCache] = None,
        cache_max_size: int = 10000,
        cache_ttl: int = 0,
        cache_policy: str = DEFAULT_EVICTION_POLICY,
        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
        cache_whitelist: List[str] = DEFAULT_WHITELIST,
    ) -> None:
        """
        Initialize a new Redis client.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Args:

        connection_pool:
            if given, it is used as-is: every connection-related argument
            (host, port, db, ssl_*, cache_*, ...) is ignored and the pool is
            NOT disconnected by `close()`. Otherwise a new pool is built from
            the arguments and is owned (and closed) by this client.
        retry_on_error:
            iterable of exception types to retry on. It is copied; the
            caller's object is never modified.
        single_connection_client:
            if `True`, connection pool is not used. In that case `Redis`
            instance use is not thread safe.
        """
        if not connection_pool:
            if charset is not None:
                warnings.warn(
                    DeprecationWarning(
                        '"charset" is deprecated. Use "encoding" instead'
                    )
                )
                encoding = charset
            if errors is not None:
                warnings.warn(
                    DeprecationWarning(
                        '"errors" is deprecated. Use "encoding_errors" instead'
                    )
                )
                encoding_errors = errors
            # Copy so that appending TimeoutError below never mutates a list
            # owned by the caller (which may be shared across clients).
            retry_on_error = list(retry_on_error) if retry_on_error else []
            if retry_on_timeout is True:
                retry_on_error.append(TimeoutError)
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "credential_provider": credential_provider,
                "protocol": protocol,
                "cache_enabled": cache_enabled,
                "client_cache": client_cache,
                "cache_max_size": cache_max_size,
                "cache_ttl": cache_ttl,
                "cache_policy": cache_policy,
                "cache_blacklist": cache_blacklist,
                "cache_whitelist": cache_whitelist,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_password": ssl_password,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_validate_ocsp_stapled": ssl_validate_ocsp_stapled,
                            "ssl_validate_ocsp": ssl_validate_ocsp,
                            "ssl_ocsp_context": ssl_ocsp_context,
                            "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
                            "ssl_min_version": ssl_min_version,
                        }
                    )
            connection_pool = ConnectionPool(**kwargs)
            self.auto_close_connection_pool = True
        else:
            self.auto_close_connection_pool = False

        self.connection_pool = connection_pool
        self.connection = None
        if single_connection_client:
            self.connection = self.connection_pool.get_connection("_")

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # The protocol may arrive as an int or (e.g. from a URL) as a string.
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

    def __repr__(self) -> str:
        return (
            f"<{type(self).__module__}.{type(self).__name__}"
            f"({repr(self.connection_pool)})>"
        )

    def get_encoder(self) -> "Encoder":
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self) -> Dict:
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional["Retry"]:
        """Return the ``retry`` entry of the pool's connection kwargs, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: "Retry") -> None:
        """Store ``retry`` in the connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def set_response_callback(self, command: str, callback: Callable) -> None:
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def load_external_module(self, funcname, func) -> None:
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    def transaction(
        self, func: Callable[["Pipeline"], None], *watches, **kwargs
    ) -> Any:
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        Keyword options:
            shard_hint: passed through to ``pipeline()``.
            value_from_callable: if true, return ``func``'s return value
                instead of the result of ``pipe.execute()``.
            watch_delay: seconds to sleep before retrying after a
                ``WatchError``.

        The whole sequence is retried, without limit, every time a
        ``WatchError`` is raised.
        """
        shard_hint = kwargs.pop("shard_hint", None)
        value_from_callable = kwargs.pop("value_from_callable", False)
        watch_delay = kwargs.pop("watch_delay", None)
        with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        pipe.watch(*watches)
                    func_value = func(pipe)
                    exec_value = pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        time.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: str,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Union[None, Any] = None,
        thread_local: bool = True,
    ):
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def pubsub(self, **kwargs):
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(self.connection_pool, **kwargs)

    def monitor(self):
        """Return a Monitor object bound to this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self):
        """Return a new client sharing this pool but holding one dedicated connection."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Best-effort cleanup: an exception escaping __del__ cannot be handled
        # by any caller (it is only reported as "Exception ignored in ..."),
        # e.g. during interpreter shutdown. Suppress it, as PubSub.__del__ does.
        try:
            self.close()
        except Exception:
            pass

    def close(self):
        """Release the dedicated connection (if any) and, if this client owns
        its pool, disconnect the pool."""
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return

        conn = self.connection
        if conn:
            self.connection = None
            self.connection_pool.release(conn)

        if self.auto_close_connection_pool:
            self.connection_pool.disconnect()

    def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        conn.send_command(*args)
        return self.parse_response(conn, command_name, **options)

    def _disconnect_raise(self, conn, error):
        """
        Close the connection and raise an exception
        if retry_on_error is not set or the error
        is not one of the specified error types
        """
        conn.disconnect()
        if conn.retry_on_error is None or not isinstance(
            error, tuple(conn.retry_on_error)
        ):
            raise error

    # COMMAND EXECUTION AND PROTOCOL PARSING
    def execute_command(self, *args, **options):
        """Execute a command and return a parsed response.

        The ``keys`` option is removed from ``options`` and only used for the
        local cache. A connection is borrowed from the pool unless this client
        holds a dedicated one; a borrowed connection is always returned to
        the pool, even if the cache lookup, send or parse step raises.
        """
        command_name = args[0]
        keys = options.pop("keys", None)
        pool = self.connection_pool
        conn = self.connection or pool.get_connection(command_name, **options)
        try:
            # Kept inside the try so a failing cache lookup cannot leak conn.
            response_from_cache = conn._get_from_local_cache(args)
            if response_from_cache is not None:
                return response_from_cache
            response = conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda error: self._disconnect_raise(conn, error),
            )
            conn._add_to_local_cache(args, response, keys)
            return response
        finally:
            if not self.connection:
                pool.release(conn)

    def parse_response(self, connection, command_name, **options):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = connection.read_response()
        except ResponseError:
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        if command_name in self.response_callbacks:
            return self.response_callbacks[command_name](response, **options)
        return response

    # The cache helpers below are best-effort: AttributeError is swallowed,
    # presumably because client-side caching may not be configured on the
    # connection/pool (NOTE(review): confirm).
    def flush_cache(self):
        try:
            if self.connection:
                self.connection.client_cache.flush()
            else:
                self.connection_pool.flush_cache()
        except AttributeError:
            pass

    def delete_command_from_cache(self, command):
        try:
            if self.connection:
                self.connection.client_cache.delete_command(command)
            else:
                self.connection_pool.delete_command_from_cache(command)
        except AttributeError:
            pass

    def invalidate_key_from_cache(self, key):
        try:
            if self.connection:
                self.connection.client_cache.invalidate_key(key)
            else:
                self.connection_pool.invalidate_key_from_cache(key)
        except AttributeError:
            pass
# ``StrictRedis`` is an alias: it names the very same class object as ``Redis``.
StrictRedis = Redis
class Monitor:
    """
    Wrapper around the Redis MONITOR command.

    ``next_command()`` reads and parses a single monitored command, and
    ``listen()`` is a generator yielding parsed commands forever. Used as a
    context manager, entering sends MONITOR and exiting disconnects and
    releases the dedicated connection.
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool):
        self.connection_pool = connection_pool
        # One connection is reserved for the lifetime of this monitor.
        self.connection = connection_pool.get_connection("MONITOR")

    def __enter__(self):
        conn = self.connection
        conn.send_command("MONITOR")
        # The server acknowledges MONITOR with OK; consume it here instead of
        # handing it to the caller as though it were a monitored command.
        reply = conn.read_response()
        if bool_ok(reply):
            return self
        raise RedisError(f"MONITOR failed: {reply}")

    def __exit__(self, *args):
        conn = self.connection
        conn.disconnect()
        self.connection_pool.release(conn)

    def next_command(self):
        """Read one MONITOR line and return it as a dict.

        The dict has the keys ``time`` (float), ``db`` (int),
        ``client_address``, ``client_port``, ``client_type`` and ``command``.
        """
        raw = self.connection.read_response()
        if isinstance(raw, bytes):
            raw = self.connection.encoder.decode(raw, force=True)
        timestamp, remainder = raw.split(" ", 1)
        db_id, client_info, quoted = self.monitor_re.match(remainder).groups()
        # Redis wraps every argument in double quotes and escapes quotes that
        # appear inside them; strip the wrapping and undo the escaping.
        command = " ".join(self.command_re.findall(quoted)).replace('\\"', '"')

        if client_info == "lua":
            address, port, kind = "lua", "", "lua"
        elif client_info.startswith("unix"):
            address, port, kind = "unix", client_info[5:], "unix"
        else:
            # rsplit because IPv6 addresses themselves contain colons.
            address, port = client_info.rsplit(":", 1)
            kind = "tcp"
        return dict(
            time=float(timestamp),
            db=int(db_id),
            client_address=address,
            client_port=port,
            client_type=kind,
            command=command,
        )

    def listen(self):
        """Yield parsed commands arriving at the server, indefinitely."""
        while True:
            yield self.next_command()
700class PubSub:
701 """
702 PubSub provides publish, subscribe and listen support to Redis channels.
704 After subscribing to one or more channels, the listen() method will block
705 until a message arrives on one of the subscribed channels. That message
706 will be returned and it's safe to start listening again.
707 """
709 PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
710 UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
711 HEALTH_CHECK_MESSAGE = "redis-py-health-check"
713 def __init__(
714 self,
715 connection_pool,
716 shard_hint=None,
717 ignore_subscribe_messages: bool = False,
718 encoder: Optional["Encoder"] = None,
719 push_handler_func: Union[None, Callable[[str], None]] = None,
720 ):
721 self.connection_pool = connection_pool
722 self.shard_hint = shard_hint
723 self.ignore_subscribe_messages = ignore_subscribe_messages
724 self.connection = None
725 self.subscribed_event = threading.Event()
726 # we need to know the encoding options for this connection in order
727 # to lookup channel and pattern names for callback handlers.
728 self.encoder = encoder
729 self.push_handler_func = push_handler_func
730 if self.encoder is None:
731 self.encoder = self.connection_pool.get_encoder()
732 self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
733 if self.encoder.decode_responses:
734 self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
735 else:
736 self.health_check_response = [b"pong", self.health_check_response_b]
737 if self.push_handler_func is None:
738 _set_info_logger()
739 self.reset()
741 def __enter__(self) -> "PubSub":
742 return self
744 def __exit__(self, exc_type, exc_value, traceback) -> None:
745 self.reset()
    def __del__(self) -> None:
        # Finalizer: runs reset(), which disconnects any connection still held
        # and releases it back to the pool.
        try:
            # if this object went out of scope prior to shutting down
            # subscriptions, close the connection manually before
            # returning it to the connection pool
            self.reset()
        except Exception:
            # Best-effort: an exception escaping __del__ cannot be handled by
            # any caller, so it is deliberately suppressed.
            pass
756 def reset(self) -> None:
757 if self.connection:
758 self.connection.disconnect()
759 self.connection.deregister_connect_callback(self.on_connect)
760 self.connection_pool.release(self.connection)
761 self.connection = None
762 self.health_check_response_counter = 0
763 self.channels = {}
764 self.pending_unsubscribe_channels = set()
765 self.shard_channels = {}
766 self.pending_unsubscribe_shard_channels = set()
767 self.patterns = {}
768 self.pending_unsubscribe_patterns = set()
769 self.subscribed_event.clear()
771 def close(self) -> None:
772 self.reset()
774 def on_connect(self, connection) -> None:
775 "Re-subscribe to any channels and patterns previously subscribed to"
776 # NOTE: for python3, we can't pass bytestrings as keyword arguments
777 # so we need to decode channel/pattern names back to unicode strings
778 # before passing them to [p]subscribe.
779 self.pending_unsubscribe_channels.clear()
780 self.pending_unsubscribe_patterns.clear()
781 self.pending_unsubscribe_shard_channels.clear()
782 if self.channels:
783 channels = {
784 self.encoder.decode(k, force=True): v for k, v in self.channels.items()
785 }
786 self.subscribe(**channels)
787 if self.patterns:
788 patterns = {
789 self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
790 }
791 self.psubscribe(**patterns)
792 if self.shard_channels:
793 shard_channels = {
794 self.encoder.decode(k, force=True): v
795 for k, v in self.shard_channels.items()
796 }
797 self.ssubscribe(**shard_channels)
799 @property
800 def subscribed(self) -> bool:
801 """Indicates if there are subscriptions to any channels or patterns"""
802 return self.subscribed_event.is_set()
804 def execute_command(self, *args):
805 """Execute a publish/subscribe command"""
807 # NOTE: don't parse the response in this function -- it could pull a
808 # legitimate message off the stack if the connection is already
809 # subscribed to one or more channels
811 if self.connection is None:
812 self.connection = self.connection_pool.get_connection(
813 "pubsub", self.shard_hint
814 )
815 # register a callback that re-subscribes to any channels we
816 # were listening to when we were disconnected
817 self.connection.register_connect_callback(self.on_connect)
818 if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
819 self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
820 connection = self.connection
821 kwargs = {"check_health": not self.subscribed}
822 if not self.subscribed:
823 self.clean_health_check_responses()
824 self._execute(connection, connection.send_command, *args, **kwargs)
826 def clean_health_check_responses(self) -> None:
827 """
828 If any health check responses are present, clean them
829 """
830 ttl = 10
831 conn = self.connection
832 while self.health_check_response_counter > 0 and ttl > 0:
833 if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
834 response = self._execute(conn, conn.read_response)
835 if self.is_health_check_response(response):
836 self.health_check_response_counter -= 1
837 else:
838 raise PubSubError(
839 "A non health check response was cleaned by "
840 "execute_command: {0}".format(response)
841 )
842 ttl -= 1
844 def _disconnect_raise_connect(self, conn, error) -> None:
845 """
846 Close the connection and raise an exception
847 if retry_on_error is not set or the error is not one
848 of the specified error types. Otherwise, try to
849 reconnect
850 """
851 conn.disconnect()
852 if (
853 conn.retry_on_error is None
854 or isinstance(error, tuple(conn.retry_on_error)) is False
855 ):
856 raise error
857 conn.connect()
859 def _execute(self, conn, command, *args, **kwargs):
860 """
861 Connect manually upon disconnection. If the Redis server is down,
862 this will fail and raise a ConnectionError as desired.
863 After reconnection, the ``on_connect`` callback should have been
864 called by the # connection to resubscribe us to any channels and
865 patterns we were previously listening to
866 """
867 return conn.retry.call_with_retry(
868 lambda: command(*args, **kwargs),
869 lambda error: self._disconnect_raise_connect(conn, error),
870 )
872 def parse_response(self, block=True, timeout=0):
873 """Parse the response from a publish/subscribe command"""
874 conn = self.connection
875 if conn is None:
876 raise RuntimeError(
877 "pubsub connection not set: "
878 "did you forget to call subscribe() or psubscribe()?"
879 )
881 self.check_health()
883 def try_read():
884 if not block:
885 if not conn.can_read(timeout=timeout):
886 return None
887 else:
888 conn.connect()
889 return conn.read_response(disconnect_on_error=False, push_request=True)
891 response = self._execute(conn, try_read)
893 if self.is_health_check_response(response):
894 # ignore the health check message as user might not expect it
895 self.health_check_response_counter -= 1
896 return None
897 return response
899 def is_health_check_response(self, response) -> bool:
900 """
901 Check if the response is a health check response.
902 If there are no subscriptions redis responds to PING command with a
903 bulk response, instead of a multi-bulk with "pong" and the response.
904 """
905 return response in [
906 self.health_check_response, # If there was a subscription
907 self.health_check_response_b, # If there wasn't
908 ]
910 def check_health(self) -> None:
911 conn = self.connection
912 if conn is None:
913 raise RuntimeError(
914 "pubsub connection not set: "
915 "did you forget to call subscribe() or psubscribe()?"
916 )
918 if conn.health_check_interval and time.time() > conn.next_health_check:
919 conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
920 self.health_check_response_counter += 1
922 def _normalize_keys(self, data) -> Dict:
923 """
924 normalize channel/pattern names to be either bytes or strings
925 based on whether responses are automatically decoded. this saves us
926 from coercing the value for each message coming in.
927 """
928 encode = self.encoder.encode
929 decode = self.encoder.decode
930 return {decode(encode(k)): v for k, v in data.items()}
932 def psubscribe(self, *args, **kwargs):
933 """
934 Subscribe to channel patterns. Patterns supplied as keyword arguments
935 expect a pattern name as the key and a callable as the value. A
936 pattern's callable will be invoked automatically when a message is
937 received on that pattern rather than producing a message via
938 ``listen()``.
939 """
940 if args:
941 args = list_or_args(args[0], args[1:])
942 new_patterns = dict.fromkeys(args)
943 new_patterns.update(kwargs)
944 ret_val = self.execute_command("PSUBSCRIBE", *new_patterns.keys())
945 # update the patterns dict AFTER we send the command. we don't want to
946 # subscribe twice to these patterns, once for the command and again
947 # for the reconnection.
948 new_patterns = self._normalize_keys(new_patterns)
949 self.patterns.update(new_patterns)
950 if not self.subscribed:
951 # Set the subscribed_event flag to True
952 self.subscribed_event.set()
953 # Clear the health check counter
954 self.health_check_response_counter = 0
955 self.pending_unsubscribe_patterns.difference_update(new_patterns)
956 return ret_val
958 def punsubscribe(self, *args):
959 """
960 Unsubscribe from the supplied patterns. If empty, unsubscribe from
961 all patterns.
962 """
963 if args:
964 args = list_or_args(args[0], args[1:])
965 patterns = self._normalize_keys(dict.fromkeys(args))
966 else:
967 patterns = self.patterns
968 self.pending_unsubscribe_patterns.update(patterns)
969 return self.execute_command("PUNSUBSCRIBE", *args)
971 def subscribe(self, *args, **kwargs):
972 """
973 Subscribe to channels. Channels supplied as keyword arguments expect
974 a channel name as the key and a callable as the value. A channel's
975 callable will be invoked automatically when a message is received on
976 that channel rather than producing a message via ``listen()`` or
977 ``get_message()``.
978 """
979 if args:
980 args = list_or_args(args[0], args[1:])
981 new_channels = dict.fromkeys(args)
982 new_channels.update(kwargs)
983 ret_val = self.execute_command("SUBSCRIBE", *new_channels.keys())
984 # update the channels dict AFTER we send the command. we don't want to
985 # subscribe twice to these channels, once for the command and again
986 # for the reconnection.
987 new_channels = self._normalize_keys(new_channels)
988 self.channels.update(new_channels)
989 if not self.subscribed:
990 # Set the subscribed_event flag to True
991 self.subscribed_event.set()
992 # Clear the health check counter
993 self.health_check_response_counter = 0
994 self.pending_unsubscribe_channels.difference_update(new_channels)
995 return ret_val
997 def unsubscribe(self, *args):
998 """
999 Unsubscribe from the supplied channels. If empty, unsubscribe from
1000 all channels
1001 """
1002 if args:
1003 args = list_or_args(args[0], args[1:])
1004 channels = self._normalize_keys(dict.fromkeys(args))
1005 else:
1006 channels = self.channels
1007 self.pending_unsubscribe_channels.update(channels)
1008 return self.execute_command("UNSUBSCRIBE", *args)
1010 def ssubscribe(self, *args, target_node=None, **kwargs):
1011 """
1012 Subscribes the client to the specified shard channels.
1013 Channels supplied as keyword arguments expect a channel name as the key
1014 and a callable as the value. A channel's callable will be invoked automatically
1015 when a message is received on that channel rather than producing a message via
1016 ``listen()`` or ``get_sharded_message()``.
1017 """
1018 if args:
1019 args = list_or_args(args[0], args[1:])
1020 new_s_channels = dict.fromkeys(args)
1021 new_s_channels.update(kwargs)
1022 ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
1023 # update the s_channels dict AFTER we send the command. we don't want to
1024 # subscribe twice to these channels, once for the command and again
1025 # for the reconnection.
1026 new_s_channels = self._normalize_keys(new_s_channels)
1027 self.shard_channels.update(new_s_channels)
1028 if not self.subscribed:
1029 # Set the subscribed_event flag to True
1030 self.subscribed_event.set()
1031 # Clear the health check counter
1032 self.health_check_response_counter = 0
1033 self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
1034 return ret_val
1036 def sunsubscribe(self, *args, target_node=None):
1037 """
1038 Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1039 all shard_channels
1040 """
1041 if args:
1042 args = list_or_args(args[0], args[1:])
1043 s_channels = self._normalize_keys(dict.fromkeys(args))
1044 else:
1045 s_channels = self.shard_channels
1046 self.pending_unsubscribe_shard_channels.update(s_channels)
1047 return self.execute_command("SUNSUBSCRIBE", *args)
1049 def listen(self):
1050 "Listen for messages on channels this client has been subscribed to"
1051 while self.subscribed:
1052 response = self.handle_message(self.parse_response(block=True))
1053 if response is not None:
1054 yield response
1056 def get_message(
1057 self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
1058 ):
1059 """
1060 Get the next message if one is available, otherwise None.
1062 If timeout is specified, the system will wait for `timeout` seconds
1063 before returning. Timeout should be specified as a floating point
1064 number, or None, to wait indefinitely.
1065 """
1066 if not self.subscribed:
1067 # Wait for subscription
1068 start_time = time.time()
1069 if self.subscribed_event.wait(timeout) is True:
1070 # The connection was subscribed during the timeout time frame.
1071 # The timeout should be adjusted based on the time spent
1072 # waiting for the subscription
1073 time_spent = time.time() - start_time
1074 timeout = max(0.0, timeout - time_spent)
1075 else:
1076 # The connection isn't subscribed to any channels or patterns,
1077 # so no messages are available
1078 return None
1080 response = self.parse_response(block=(timeout is None), timeout=timeout)
1081 if response:
1082 return self.handle_message(response, ignore_subscribe_messages)
1083 return None
1085 get_sharded_message = get_message
1087 def ping(self, message: Union[str, None] = None) -> bool:
1088 """
1089 Ping the Redis server
1090 """
1091 args = ["PING", message] if message is not None else ["PING"]
1092 return self.execute_command(*args)
1094 def handle_message(self, response, ignore_subscribe_messages=False):
1095 """
1096 Parses a pub/sub message. If the channel or pattern was subscribed to
1097 with a message handler, the handler is invoked instead of a parsed
1098 message being returned.
1099 """
1100 if response is None:
1101 return None
1102 if isinstance(response, bytes):
1103 response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
1104 message_type = str_if_bytes(response[0])
1105 if message_type == "pmessage":
1106 message = {
1107 "type": message_type,
1108 "pattern": response[1],
1109 "channel": response[2],
1110 "data": response[3],
1111 }
1112 elif message_type == "pong":
1113 message = {
1114 "type": message_type,
1115 "pattern": None,
1116 "channel": None,
1117 "data": response[1],
1118 }
1119 else:
1120 message = {
1121 "type": message_type,
1122 "pattern": None,
1123 "channel": response[1],
1124 "data": response[2],
1125 }
1127 # if this is an unsubscribe message, remove it from memory
1128 if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
1129 if message_type == "punsubscribe":
1130 pattern = response[1]
1131 if pattern in self.pending_unsubscribe_patterns:
1132 self.pending_unsubscribe_patterns.remove(pattern)
1133 self.patterns.pop(pattern, None)
1134 elif message_type == "sunsubscribe":
1135 s_channel = response[1]
1136 if s_channel in self.pending_unsubscribe_shard_channels:
1137 self.pending_unsubscribe_shard_channels.remove(s_channel)
1138 self.shard_channels.pop(s_channel, None)
1139 else:
1140 channel = response[1]
1141 if channel in self.pending_unsubscribe_channels:
1142 self.pending_unsubscribe_channels.remove(channel)
1143 self.channels.pop(channel, None)
1144 if not self.channels and not self.patterns and not self.shard_channels:
1145 # There are no subscriptions anymore, set subscribed_event flag
1146 # to false
1147 self.subscribed_event.clear()
1149 if message_type in self.PUBLISH_MESSAGE_TYPES:
1150 # if there's a message handler, invoke it
1151 if message_type == "pmessage":
1152 handler = self.patterns.get(message["pattern"], None)
1153 elif message_type == "smessage":
1154 handler = self.shard_channels.get(message["channel"], None)
1155 else:
1156 handler = self.channels.get(message["channel"], None)
1157 if handler:
1158 handler(message)
1159 return None
1160 elif message_type != "pong":
1161 # this is a subscribe/unsubscribe message. ignore if we don't
1162 # want them
1163 if ignore_subscribe_messages or self.ignore_subscribe_messages:
1164 return None
1166 return message
1168 def run_in_thread(
1169 self,
1170 sleep_time: float = 0.0,
1171 daemon: bool = False,
1172 exception_handler: Optional[Callable] = None,
1173 ) -> "PubSubWorkerThread":
1174 for channel, handler in self.channels.items():
1175 if handler is None:
1176 raise PubSubError(f"Channel: '{channel}' has no handler registered")
1177 for pattern, handler in self.patterns.items():
1178 if handler is None:
1179 raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
1180 for s_channel, handler in self.shard_channels.items():
1181 if handler is None:
1182 raise PubSubError(
1183 f"Shard Channel: '{s_channel}' has no handler registered"
1184 )
1186 thread = PubSubWorkerThread(
1187 self, sleep_time, daemon=daemon, exception_handler=exception_handler
1188 )
1189 thread.start()
1190 return thread
class PubSubWorkerThread(threading.Thread):
    """
    Thread that repeatedly calls ``pubsub.get_message()`` (ignoring
    subscribe messages) until ``stop()`` is called, then closes the pubsub.
    """

    def __init__(
        self,
        pubsub,
        sleep_time: float,
        daemon: bool = False,
        exception_handler: Union[
            Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None
        ] = None,
    ):
        """
        :param pubsub: object whose ``get_message()`` is polled and whose
            ``close()`` is called when the loop exits normally.
        :param sleep_time: timeout passed to every ``get_message()`` call.
        :param daemon: value assigned to ``Thread.daemon``.
        :param exception_handler: optional callable invoked as
            ``exception_handler(exc, pubsub, thread)`` when polling raises;
            without it the exception propagates and ends the thread.
        """
        super().__init__()
        self.daemon = daemon
        self.pubsub = pubsub
        self.sleep_time = sleep_time
        self.exception_handler = exception_handler
        # Set while run()'s polling loop is active.
        self._running = threading.Event()
        # Set by stop(). Kept separate from _running because Thread.start()
        # can return before run() has set _running; a stop() issued in that
        # window would otherwise be lost and the loop would never exit.
        self._stop_requested = threading.Event()

    def run(self) -> None:
        # Guard against the loop being entered twice.
        if self._running.is_set():
            return
        self._running.set()
        pubsub = self.pubsub
        sleep_time = self.sleep_time
        while self._running.is_set() and not self._stop_requested.is_set():
            try:
                pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time)
            except BaseException as e:
                if self.exception_handler is None:
                    # NOTE: an unhandled exception ends the thread without
                    # reaching pubsub.close() below.
                    raise
                self.exception_handler(e, pubsub, self)
        pubsub.close()

    def stop(self) -> None:
        # trip the flags so the run loop exits. the run loop will
        # close the pubsub connection, which disconnects the socket
        # and returns the connection to the pool.
        self._stop_requested.set()
        self._running.clear()
class Pipeline(Redis):
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    When the pipeline is transactional (created with ``transaction=True`` or
    after an explicit call to ``multi()``), the buffered commands are wrapped
    with MULTI and EXEC calls, which guarantees they are executed atomically.
    Otherwise they are sent in one batch without MULTI/EXEC.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. With ``execute(raise_on_error=False)``
    the exception is caught and its instance is placed into the response list
    returned by execute(); by default the first such error is raised instead.
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    # Commands after which the connection is no longer WATCHing any keys
    # (see parse_response).
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

    def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
        self.connection_pool = connection_pool
        self.connection = None
        self.response_callbacks = response_callbacks
        self.transaction = transaction
        self.shard_hint = shard_hint

        self.watching = False
        self.reset()

    def __enter__(self) -> "Pipeline":
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.reset()

    def __del__(self):
        # Best-effort cleanup during garbage collection; errors are ignored
        # deliberately because there is no caller to report them to.
        try:
            self.reset()
        except Exception:
            pass

    def __len__(self) -> int:
        return len(self.command_stack)

    def __bool__(self) -> bool:
        """Pipeline instances should always evaluate to True"""
        return True

    def reset(self) -> None:
        """
        Drop buffered commands and scripts, clear any WATCH/MULTI state and
        return the held connection (if any) to the pool.
        """
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                self.connection.send_command("UNWATCH")
                self.connection.read_response()
            except (ConnectionError, TimeoutError):
                # The connection is unusable, or its reply state is unknown
                # after a timeout, so drop it; disconnect will also remove any
                # previous WATCHes. Handling TimeoutError here keeps reset()
                # from raising and leaking the connection.
                self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            self.connection_pool.release(self.connection)
            self.connection = None

    def close(self) -> None:
        """Close the pipeline"""
        self.reset()

    def multi(self) -> None:
        """
        Start a transactional block of the pipeline after WATCH commands
        are issued. End the transactional block with `execute`.

        Raises RedisError if called twice, or if commands have already been
        buffered.
        """
        if self.explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self.command_stack:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self.explicit_transaction = True

    def execute_command(self, *args, **kwargs):
        """
        Run the command immediately when WATCHing (or for WATCH itself)
        outside an explicit transaction; otherwise buffer it for execute().
        """
        kwargs.pop("keys", None)  # the keys are used only for client side caching
        if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
            return self.immediate_execute_command(*args, **kwargs)
        return self.pipeline_execute_command(*args, **kwargs)

    def _disconnect_reset_raise(self, conn, error) -> None:
        """
        Close the connection, reset watching state and
        raise an exception if we were watching,
        if retry_on_error is not set or the error is not one
        of the specified error types.
        """
        conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            self.reset()
            raise WatchError(
                "A ConnectionError occurred on while watching one or more keys"
            )
        # if retry_on_error is not set or the error is not one
        # of the specified error types, raise it
        if (
            conn.retry_on_error is None
            or isinstance(error, tuple(conn.retry_on_error)) is False
        ):
            self.reset()
            raise error

    def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on a
        ConnectionError if we're already WATCHing a variable. Used when
        issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = self.connection_pool.get_connection(command_name, self.shard_hint)
            self.connection = conn

        return conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise(conn, error),
        )

    def pipeline_execute_command(self, *args, **options) -> "Pipeline":
        """
        Stage a command to be executed when execute() is next called

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self

    def _execute_transaction(self, connection, commands, raise_on_error) -> List:
        """
        Send MULTI, the buffered commands and EXEC as one packed request and
        read every reply. Returns the EXEC results with per-command errors
        inserted at their positions and response callbacks applied.
        """
        cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
        # Commands carrying EMPTY_RESPONSE are not sent at all; their
        # placeholder value is spliced into the results below.
        all_cmds = connection.pack_commands(
            [args for args, options in cmds if EMPTY_RESPONSE not in options]
        )
        connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            self.parse_response(connection, "_")
        except ResponseError as e:
            errors.append((0, e))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    self.parse_response(connection, "_")
                except ResponseError as e:
                    self.annotate_exception(e, i + 1, command[0])
                    errors.append((i, e))

        # parse the EXEC.
        try:
            response = self.parse_response(connection, "_")
        except ExecAbortError:
            if errors:
                raise errors[0][1]
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            # The reply stream is out of sync; drop the connection we used.
            connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            )

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]
                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
            data.append(r)
        return data

    def _execute_pipeline(self, connection, commands, raise_on_error):
        """
        Send all buffered commands in one request (no MULTI/EXEC) and return
        their parsed replies, with ResponseErrors kept in place.
        """
        # build up all commands into a single request to increase network perf
        all_cmds = connection.pack_commands([args for args, _ in commands])
        connection.send_packed_command(all_cmds)

        response = []
        for args, options in commands:
            try:
                response.append(self.parse_response(connection, args[0], **options))
            except ResponseError as e:
                response.append(e)

        if raise_on_error:
            self.raise_first_error(commands, response)
        return response

    def raise_first_error(self, commands, response):
        """Annotate and raise the first ResponseError found in ``response``."""
        for i, r in enumerate(response):
            if isinstance(r, ResponseError):
                self.annotate_exception(r, i + 1, commands[i][0])
                raise r

    def annotate_exception(self, exception, number, command):
        """Prefix the exception message with the failing command's position."""
        cmd = " ".join(map(safe_str, command))
        msg = (
            f"Command # {number} ({cmd}) of pipeline "
            f"caused error: {exception.args[0]}"
        )
        exception.args = (msg,) + exception.args[1:]

    def parse_response(self, connection, command_name, **options):
        """Parse a reply and keep the ``watching`` flag in sync with it."""
        result = Redis.parse_response(self, connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

    def load_scripts(self):
        """Ensure every script queued on this pipeline is loaded on the server."""
        # make sure all scripts that are about to be run on this pipeline exist
        scripts = list(self.scripts)
        immediate = self.immediate_execute_command
        shas = [s.sha for s in scripts]
        # we can't use the normal script_* methods because they would just
        # get buffered in the pipeline.
        exists = immediate("SCRIPT EXISTS", *shas)
        if not all(exists):
            for s, exist in zip(scripts, exists):
                if not exist:
                    s.sha = immediate("SCRIPT LOAD", s.script)

    def _disconnect_raise_reset(
        self,
        conn: AbstractConnection,
        error: Exception,
    ) -> None:
        """
        Close the connection, raise an exception if we were watching,
        and raise an exception if retry_on_error is not set or the
        error is not one of the specified error types.
        """
        conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                "A ConnectionError occurred on while watching one or more keys"
            )
        # if retry_on_error is not set or the error is not one
        # of the specified error types, raise it
        if (
            conn.retry_on_error is None
            or isinstance(error, tuple(conn.retry_on_error)) is False
        ):

            self.reset()
            raise error

    def execute(self, raise_on_error=True):
        """Execute all the commands in the current pipeline"""
        stack = self.command_stack
        if not stack and not self.watching:
            return []
        if self.scripts:
            self.load_scripts()
        if self.transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = self.connection_pool.get_connection("MULTI", self.shard_hint)
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn

        try:
            return conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_reset(conn, error),
            )
        finally:
            self.reset()

    def discard(self):
        """
        Flushes all previously queued commands
        See: https://redis.io/commands/DISCARD
        """
        # NOTE(review): this goes through execute_command, so inside an
        # explicit transaction (or when not WATCHing) DISCARD is only appended
        # to command_stack rather than dropping the buffered commands locally;
        # confirm this is the intended behavior.
        self.execute_command("DISCARD")

    def watch(self, *names):
        """Watches the values at keys ``names``"""
        if self.explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")
        return self.execute_command("WATCH", *names)

    def unwatch(self) -> bool:
        """Unwatches all previously specified keys"""
        if not self.watching:
            # Nothing is being watched, so there is nothing to send.
            return True
        return self.execute_command("UNWATCH") or True