Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .event import AfterConnectionReleasedEvent, EventDispatcher
39from .exceptions import (
40 AuthenticationError,
41 AuthenticationWrongNumberOfArgsError,
42 ChildDeadlockedError,
43 ConnectionError,
44 DataError,
45 MaxConnectionsError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from .maint_notifications import (
51 MaintenanceState,
52 MaintNotificationsConfig,
53 MaintNotificationsConnectionHandler,
54 MaintNotificationsPoolHandler,
55)
56from .retry import Retry
57from .utils import (
58 CRYPTOGRAPHY_AVAILABLE,
59 HIREDIS_AVAILABLE,
60 SSL_AVAILABLE,
61 compare_versions,
62 deprecated_args,
63 ensure_string,
64 format_error_message,
65 get_lib_version,
66 str_if_bytes,
67)
69if SSL_AVAILABLE:
70 import ssl
71 from ssl import VerifyFlags
72else:
73 ssl = None
74 VerifyFlags = None
76if HIREDIS_AVAILABLE:
77 import hiredis
# RESP wire-protocol tokens used when serializing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when ``protocol`` cannot be converted with int()
# (for example when it is ``None``).
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied", for parameters where
# ``None`` is itself a meaningful value.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the extension is importable;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
class HiredisRespSerializer:
    """Serialize commands into the Redis protocol using the hiredis extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A command name that contains spaces (e.g. ``"CONFIG GET"``) is split
        into separate leading arguments before packing.

        Returns:
            A single-element list holding the packed command.

        Raises:
            DataError: if hiredis rejects one of the arguments with TypeError.
        """
        head, rest = args[0], args[1:]
        if isinstance(head, str):
            args = (*head.encode().split(), *rest)
        elif b" " in head:
            args = (*head.split(), *rest)
        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Re-raise as the library's DataError while keeping the original
            # traceback for debugging.
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Pure-Python serializer that packs commands into the Redis protocol."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Size threshold (in bytes) above which chunks are emitted separately
        # instead of being concatenated into one growing buffer.
        self._buffer_cutoff = buffer_cutoff
        # Callable applied to every argument before it is framed.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns:
            A list of byte chunks whose concatenation is the RESP array
            encoding of ``args``. Large arguments and memoryviews are placed
            in their own chunk so they are not copied into a larger buffer.
        """
        # The client may embed literal sub-arguments in the command name
        # (e.g. 'CONFIG GET'); the server expects them as separate elements,
        # so split them off. They are kept as bytestrings so that they are
        # not encoded again.
        head, rest = args[0], args[1:]
        if isinstance(head, str):
            args = tuple(head.encode().split()) + rest
        elif b" " in head:
            args = tuple(head.split()) + rest

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            emit_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if not emit_separately:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))
                continue
            # Flush what we have plus this argument's header, then hand the
            # argument over as its own chunk; its trailing CRLF starts the
            # next pending buffer.
            chunks.append(SYM_EMPTY.join((pending, header)))
            chunks.append(encoded)
            pending = SYM_CRLF
        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """Operations that every synchronous Redis connection provides.

    The methods are marked with ``abstractmethod`` to document the contract.
    This class does not use ``ABCMeta``, so the marker is not enforced when
    instantiating.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs shown by ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run whenever the connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously added with ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a new parser instance created from ``parser_class``."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by the connection."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Initialize a freshly opened connection (handshake, auth, db select)."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""

    @abstractmethod
    def check_health(self):
        """Verify the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send a command that has already been packed into wire format."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Send a single command built from ``args``."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Report whether data is available to read (presumably within ``timeout``)."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and parse one response from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Pack a single command into wire format."""

    @abstractmethod
    def pack_commands(self, commands):
        """Pack several commands into wire format."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server reply captured during the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for re-authenticating the connection."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.

    Concrete subclasses must implement ``_get_parser``, ``_get_socket``,
    ``get_protocol``, ``send_command``, ``read_response``, ``disconnect`` and
    the ``host``, ``socket_timeout`` and ``socket_connect_timeout``
    properties. The class does not use ``ABCMeta``, so ``abstractmethod``
    documents that contract without enforcing it.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser on which the
                notification push handlers are installed. Required when
                notifications are enabled in the config.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing
                or is neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        When notifications are disabled, both handlers are set to ``None``
        and the original connection parameters are NOT recorded.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters (restored by reset_tmp_settings).
        # Falsy values (None, 0, "") fall back to the current settings.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of a pool-level notification handler.

        Also creates the connection-level handler if it does not exist yet
        (e.g. notifications were disabled when the connection was built);
        otherwise only that handler's config is replaced.
        """
        # Use a per-connection copy of the pool handler to avoid sharing the
        # same pool handler between multiple connections, because otherwise
        # each connection will override the connection reference and the pool
        # handler will only hold a reference to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
            # _configure_maintenance_notifications() records the original
            # host/timeouts only when notifications were enabled at
            # construction time. Capture them now if that did not happen,
            # otherwise reset_tmp_settings() fails with AttributeError.
            if not hasattr(self, "orig_host_address"):
                self.orig_host_address = self.host
            if not hasattr(self, "orig_socket_timeout"):
                self.orig_socket_timeout = self.socket_timeout
            if not hasattr(self, "orig_socket_connect_timeout"):
                self.orig_socket_connect_timeout = self.socket_connect_timeout
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Run the maintenance-notifications handshake when applicable.

        The handshake is sent only when the protocol is not RESP2,
        notifications are enabled in the config and a connection handler
        exists. In "auto" mode a server-side rejection is only logged; when
        ``enabled`` is True, the failure is raised.
        """
        # NOTE: hasattr(self, "host") is always true because ``host`` is a
        # property on this class; a missing host is detected (host is None)
        # inside _enable_maintenance_notifications instead.
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send ``CLIENT MAINT_NOTIFICATIONS ON`` and verify the ``OK`` reply.

        Raises:
            ValueError: if the connection has no host.
            ResponseError: if the server does not reply ``OK``, unless the
                config's ``enabled`` is ``"auto"`` (then it is only logged).
            Any other exception from sending/reading the command.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # "auto" mode: the server may simply not support the feature,
                # so log at debug level and keep the connection usable.
                import logging

                logger = logging.getLogger(__name__)
                logger.debug("Failed to enable maintenance notifications: %s", e)
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution of ``self.host``/``self.port``.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # IP sockets report a (host, port, ...) tuple; AF_UNIX sockets
                # report a path (str/bytes) whose first character is not an
                # IP address, so only tuples are used here.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.

        For IP sockets this is the first element of the peer address tuple
        (the host/IP); for other address families (e.g. AF_UNIX) the address
        is returned unchanged. Returns ``None`` when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            peer_addr = conn_socket.getpeername()
            if isinstance(peer_addr, tuple):
                return peer_addr[0]
            return peer_addr
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply ``relaxed_timeout`` to the live socket and parser.

        A value of ``-1`` means "use the connection's ``socket_timeout``".
        Does nothing when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's internal read state."""
        parser = self._get_parser()
        # presumably ``_buffer`` is unset/None while the parser is
        # disconnected -- TODO confirm against the parser implementations
        if parser and parser._buffer:
            # NOTE(review): a falsy timeout (None/0) is not propagated to the
            # RESP3 buffer, unlike the hiredis branch; confirm intended.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override the host and/or timeouts of the connection.

        The value of SENTINEL is used to indicate that the host should not be
        updated (a falsy host is ignored as well). A ``tmp_relaxed_timeout``
        of ``-1`` leaves the timeouts untouched; any other value -- including
        the default ``None`` -- is written to both ``socket_timeout`` and
        ``socket_connect_timeout``.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the host and/or timeouts recorded as the original settings."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
656 "Manages communication to and from a Redis server"
658 def __init__(
659 self,
660 db: int = 0,
661 password: Optional[str] = None,
662 socket_timeout: Optional[float] = None,
663 socket_connect_timeout: Optional[float] = None,
664 retry_on_timeout: bool = False,
665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
666 encoding: str = "utf-8",
667 encoding_errors: str = "strict",
668 decode_responses: bool = False,
669 parser_class=DefaultParser,
670 socket_read_size: int = 65536,
671 health_check_interval: int = 0,
672 client_name: Optional[str] = None,
673 lib_name: Optional[str] = "redis-py",
674 lib_version: Optional[str] = get_lib_version(),
675 username: Optional[str] = None,
676 retry: Union[Any, None] = None,
677 redis_connect_func: Optional[Callable[[], None]] = None,
678 credential_provider: Optional[CredentialProvider] = None,
679 protocol: Optional[int] = 2,
680 command_packer: Optional[Callable[[], None]] = None,
681 event_dispatcher: Optional[EventDispatcher] = None,
682 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
683 maint_notifications_pool_handler: Optional[
684 MaintNotificationsPoolHandler
685 ] = None,
686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
687 maintenance_notification_hash: Optional[int] = None,
688 orig_host_address: Optional[str] = None,
689 orig_socket_timeout: Optional[float] = None,
690 orig_socket_connect_timeout: Optional[float] = None,
691 ):
692 """
693 Initialize a new Connection.
694 To specify a retry policy for specific errors, first set
695 `retry_on_error` to a list of the error/s to retry on, then set
696 `retry` to a valid `Retry` object.
697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
698 """
699 if (username or password) and credential_provider is not None:
700 raise DataError(
701 "'username' and 'password' cannot be passed along with 'credential_"
702 "provider'. Please provide only one of the following arguments: \n"
703 "1. 'password' and (optional) 'username'\n"
704 "2. 'credential_provider'"
705 )
706 if event_dispatcher is None:
707 self._event_dispatcher = EventDispatcher()
708 else:
709 self._event_dispatcher = event_dispatcher
710 self.pid = os.getpid()
711 self.db = db
712 self.client_name = client_name
713 self.lib_name = lib_name
714 self.lib_version = lib_version
715 self.credential_provider = credential_provider
716 self.password = password
717 self.username = username
718 self._socket_timeout = socket_timeout
719 if socket_connect_timeout is None:
720 socket_connect_timeout = socket_timeout
721 self._socket_connect_timeout = socket_connect_timeout
722 self.retry_on_timeout = retry_on_timeout
723 if retry_on_error is SENTINEL:
724 retry_on_errors_list = []
725 else:
726 retry_on_errors_list = list(retry_on_error)
727 if retry_on_timeout:
728 # Add TimeoutError to the errors list to retry on
729 retry_on_errors_list.append(TimeoutError)
730 self.retry_on_error = retry_on_errors_list
731 if retry or self.retry_on_error:
732 if retry is None:
733 self.retry = Retry(NoBackoff(), 1)
734 else:
735 # deep-copy the Retry object as it is mutable
736 self.retry = copy.deepcopy(retry)
737 if self.retry_on_error:
738 # Update the retry's supported errors with the specified errors
739 self.retry.update_supported_errors(self.retry_on_error)
740 else:
741 self.retry = Retry(NoBackoff(), 0)
742 self.health_check_interval = health_check_interval
743 self.next_health_check = 0
744 self.redis_connect_func = redis_connect_func
745 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
746 self.handshake_metadata = None
747 self._sock = None
748 self._socket_read_size = socket_read_size
749 self._connect_callbacks = []
750 self._buffer_cutoff = 6000
751 self._re_auth_token: Optional[TokenInterface] = None
752 try:
753 p = int(protocol)
754 except TypeError:
755 p = DEFAULT_RESP_VERSION
756 except ValueError:
757 raise ConnectionError("protocol must be an integer")
758 finally:
759 if p < 2 or p > 3:
760 raise ConnectionError("protocol must be either 2 or 3")
761 # p = DEFAULT_RESP_VERSION
762 self.protocol = p
763 if self.protocol == 3 and parser_class == _RESP2Parser:
764 # If the protocol is 3 but the parser is RESP2, change it to RESP3
765 # This is needed because the parser might be set before the protocol
766 # or might be provided as a kwarg to the constructor
767 # We need to react on discrepancy only for RESP2 and RESP3
768 # as hiredis supports both
769 parser_class = _RESP3Parser
770 self.set_parser(parser_class)
772 self._command_packer = self._construct_command_packer(command_packer)
773 self._should_reconnect = False
775 # Set up maintenance notifications
776 MaintNotificationsAbstractConnection.__init__(
777 self,
778 maint_notifications_config,
779 maint_notifications_pool_handler,
780 maintenance_state,
781 maintenance_notification_hash,
782 orig_host_address,
783 orig_socket_timeout,
784 orig_socket_connect_timeout,
785 self._parser,
786 )
788 def __repr__(self):
789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
792 @abstractmethod
793 def repr_pieces(self):
794 pass
796 def __del__(self):
797 try:
798 self.disconnect()
799 except Exception:
800 pass
802 def _construct_command_packer(self, packer):
803 if packer is not None:
804 return packer
805 elif HIREDIS_AVAILABLE:
806 return HiredisRespSerializer()
807 else:
808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
810 def register_connect_callback(self, callback):
811 """
812 Register a callback to be called when the connection is established either
813 initially or reconnected. This allows listeners to issue commands that
814 are ephemeral to the connection, for example pub/sub subscription or
815 key tracking. The callback must be a _method_ and will be kept as
816 a weak reference.
817 """
818 wm = weakref.WeakMethod(callback)
819 if wm not in self._connect_callbacks:
820 self._connect_callbacks.append(wm)
822 def deregister_connect_callback(self, callback):
823 """
824 De-register a previously registered callback. It will no-longer receive
825 notifications on connection events. Calling this is not required when the
826 listener goes away, since the callbacks are kept as weak methods.
827 """
828 try:
829 self._connect_callbacks.remove(weakref.WeakMethod(callback))
830 except ValueError:
831 pass
833 def set_parser(self, parser_class):
834 """
835 Creates a new instance of parser_class with socket size:
836 _socket_read_size and assigns it to the parser for the connection
837 :param parser_class: The required parser class
838 """
839 self._parser = parser_class(socket_read_size=self._socket_read_size)
841 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
842 return self._parser
844 def connect(self):
845 "Connects to the Redis server if not already connected"
846 # try once the socket connect with the handshake, retry the whole
847 # connect/handshake flow based on retry policy
848 self.retry.call_with_retry(
849 lambda: self.connect_check_health(
850 check_health=True, retry_socket_connect=False
851 ),
852 lambda error: self.disconnect(error),
853 )
855 def connect_check_health(
856 self, check_health: bool = True, retry_socket_connect: bool = True
857 ):
858 if self._sock:
859 return
860 try:
861 if retry_socket_connect:
862 sock = self.retry.call_with_retry(
863 lambda: self._connect(), lambda error: self.disconnect(error)
864 )
865 else:
866 sock = self._connect()
867 except socket.timeout:
868 raise TimeoutError("Timeout connecting to server")
869 except OSError as e:
870 raise ConnectionError(self._error_message(e))
872 self._sock = sock
873 try:
874 if self.redis_connect_func is None:
875 # Use the default on_connect function
876 self.on_connect_check_health(check_health=check_health)
877 else:
878 # Use the passed function redis_connect_func
879 self.redis_connect_func(self)
880 except RedisError:
881 # clean up after any error in on_connect
882 self.disconnect()
883 raise
885 # run any user callbacks. right now the only internal callback
886 # is for pubsub channel/pattern resubscription
887 # first, remove any dead weakrefs
888 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
889 for ref in self._connect_callbacks:
890 callback = ref()
891 if callback:
892 callback(self)
894 @abstractmethod
895 def _connect(self):
896 pass
898 @abstractmethod
899 def _host_error(self):
900 pass
902 def _error_message(self, exception):
903 return format_error_message(self._host_error(), exception)
905 def on_connect(self):
906 self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection, authenticate and select a database.

        Handshake order: parser setup; authentication and/or RESP version
        negotiation (``HELLO ... AUTH``, ``AUTH`` or ``HELLO``); maintenance
        notifications activation; ``CLIENT SETNAME``; ``CLIENT SETINFO``
        (library name and version, errors ignored); ``SELECT``.

        Args:
            check_health: forwarded to ``send_command`` for the non-auth
                commands; authentication commands always use
                ``check_health=False``.

        Raises:
            AuthenticationError: if ``AUTH`` does not reply ``OK``.
            ConnectionError: if a credential-less ``HELLO`` reports a different
                protocol, or if ``SETNAME``/``SELECT`` do not reply ``OK``.
        """
        self._parser.on_connect(self)
        # Keep a reference to the initial parser so its EXCEPTION_CLASSES can
        # be copied onto a replacement RESP3 parser below.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            # HELLO ... AUTH requires a username; a password-only credential
            # is sent as the "default" user.
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # NOTE(review): unlike the credential-less HELLO branch below, the
            # negotiated "proto" field of this reply is not verified here.
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The "proto" key may arrive as bytes or str depending on decoding.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # CLIENT SETINFO failures (ResponseError) are ignored: library
        # metadata is informational only.
        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
def disconnect(self, *args):
    """Close the connection to the Redis server.

    The parser is notified first. Then the socket reference is cleared and
    the reconnect flag is reset. Finally, the detached socket (if there was
    one) is shut down and closed. OS-level errors raised while tearing the
    socket down are ignored.

    Args:
        *args: Accepted and ignored.
    """
    self._parser.on_disconnect()

    sock, self._sock = self._sock, None
    self.reset_should_reconnect()
    if sock is None:
        return

    # shutdown() is only issued from the process recorded in self.pid.
    # Presumably this keeps a forked child from tearing down a socket it
    # inherited; close() is done unconditionally.
    if os.getpid() == self.pid:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except (OSError, TypeError):
            pass

    try:
        sock.close()
    except OSError:
        pass
def mark_for_reconnect(self):
    """Raise the reconnect flag that ``should_reconnect()`` reports."""
    self._should_reconnect = True
def should_reconnect(self):
    """Return the reconnect flag set by ``mark_for_reconnect()``."""
    flag = self._should_reconnect
    return flag
def reset_should_reconnect(self):
    """Lower the reconnect flag (``disconnect()`` also calls this)."""
    self._should_reconnect = False
def _send_ping(self):
    """Send a PING and require PONG as the reply.

    The PING is sent with ``check_health=False`` so that it does not trigger
    another health check.

    Raises:
        ConnectionError: if the reply is anything other than ``PONG``.
    """
    self.send_command("PING", check_health=False)
    reply = str_if_bytes(self.read_response())
    if reply != "PONG":
        raise ConnectionError("Bad response from PING health check")
1061 def _ping_failed(self, error):
1062 """Function to call when PING fails"""
1063 self.disconnect()
def check_health(self):
    """Run a PING/PONG round-trip if health checks are due.

    Nothing happens when ``health_check_interval`` is falsy or the current
    monotonic time has not passed ``next_health_check``. Otherwise the PING
    is run through ``self.retry``, with ``_ping_failed`` as the failure hook.
    """
    if not self.health_check_interval:
        return
    if time.monotonic() > self.next_health_check:
        self.retry.call_with_retry(self._send_ping, self._ping_failed)
def send_packed_command(self, command, check_health=True):
    """Send an already packed command to the Redis server.

    Args:
        command: The packed command. This is either a single chunk
            (``bytes``/``bytearray``/``memoryview``, or ``str``) or an
            iterable of chunks, as produced by ``pack_command()`` or
            ``pack_commands()``.
        check_health: Whether to run ``check_health()`` before sending.
            Internal callers pass ``False`` to avoid recursing into the
            health check's own PING.

    Raises:
        TimeoutError: if writing to the socket times out.
        ConnectionError: on any other OS-level socket error.

    The connection is disconnected on any failure.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # guard against health check recursion
    if check_health:
        self.check_health()
    # A single chunk must be sent as one item: iterating over bytes would
    # yield ints, which sendall() rejects.
    if isinstance(command, (str, bytes, bytearray, memoryview)):
        command = [command]
    try:
        for item in command:
            self._sock.sendall(item)
    except socket.timeout as e:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket") from e
    except OSError as e:
        self.disconnect()
        if len(e.args) >= 2:
            err_no, errmsg = e.args[0], e.args[1]
        else:
            # OSError may carry just a message, or no arguments at all.
            err_no, errmsg = "UNKNOWN", (e.args[0] if e.args else "")
        raise ConnectionError(
            f"Error {err_no} while writing to socket. {errmsg}."
        ) from e
    except BaseException:
        # BaseExceptions can be raised when a socket send operation is not
        # finished, e.g. due to a timeout. Ideally, a caller could then re-try
        # to send un-sent data. However, the send_packed_command() API
        # does not support it so there is no point in keeping the connection open.
        self.disconnect()
        raise
def send_command(self, *args, **kwargs):
    """Pack ``args`` into a Redis command and send it.

    Only the ``check_health`` keyword argument (default True) is used here;
    any other keyword arguments are ignored.
    """
    packed = self._command_packer.pack(*args)
    self.send_packed_command(
        packed, check_health=kwargs.get("check_health", True)
    )
def can_read(self, timeout=0):
    """Poll the socket to see if there's data that can be read.

    Connects first if there is no socket.

    Args:
        timeout: Forwarded to the parser's ``can_read``.

    Raises:
        ConnectionError: if polling raises an OSError. The connection is
            disconnected before the error is raised.
    """
    if not self._sock:
        self.connect()

    host_error = self._host_error()
    try:
        return self._parser.can_read(timeout)
    except OSError as e:
        self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
def read_response(
    self,
    disable_decoding=False,
    *,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the reply to a previously sent command and return it.

    Args:
        disable_decoding: Forwarded to the parser's ``read_response``.
        disconnect_on_error: If true (the default), the connection is
            disconnected whenever reading fails.
        push_request: Forwarded to the parser only when ``self.protocol``
            is RESP3 (``3`` or ``"3"``).

    Raises:
        TimeoutError: if reading from the socket times out.
        ConnectionError: on any other OS-level socket error.
        ResponseError: if the parser returned a ResponseError instance.
    """
    host_error = self._host_error()

    try:
        parser_kwargs = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        response = self._parser.read_response(**parser_kwargs)
    except socket.timeout:
        if disconnect_on_error:
            self.disconnect()
        raise TimeoutError(f"Timeout reading from {host_error}")
    except OSError as e:
        if disconnect_on_error:
            self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
    except BaseException:
        # Also by default close in case of BaseException. A lot of code
        # relies on this behaviour when doing Command/Response pairs.
        # See #1128.
        if disconnect_on_error:
            self.disconnect()
        raise

    if self.health_check_interval:
        # A successful read counts as proof of liveness, so the next health
        # check is pushed out by one interval.
        self.next_health_check = time.monotonic() + self.health_check_interval

    if not isinstance(response, ResponseError):
        return response
    try:
        raise response
    finally:
        del response  # avoid creating ref cycles
def pack_command(self, *args):
    """Pack ``args`` into the Redis protocol with the connection's packer."""
    packer = self._command_packer
    return packer.pack(*args)
def pack_commands(self, commands):
    """Pack several commands into a list of chunks ready for sending.

    Small chunks are joined together into buffers. A pending buffer is
    flushed before the next chunk once it has grown past
    ``self._buffer_cutoff``. Chunks larger than the cutoff, and memoryview
    chunks, are emitted on their own, unjoined.

    Args:
        commands: An iterable of argument sequences, one per command.

    Returns:
        The list of packed chunks.
    """
    cutoff = self._buffer_cutoff
    packed = []
    pending = []
    pending_len = 0

    for cmd in commands:
        for chunk in self._command_packer.pack(*cmd):
            size = len(chunk)
            standalone = size > cutoff or isinstance(chunk, memoryview)
            if pending and (pending_len > cutoff or standalone):
                packed.append(SYM_EMPTY.join(pending))
                pending_len = 0
                pending = []
            if standalone:
                packed.append(chunk)
            else:
                pending.append(chunk)
                pending_len += size

    if pending:
        packed.append(SYM_EMPTY.join(pending))
    return packed
def get_protocol(self) -> Union[int, str]:
    """Return the configured protocol value as given (int or str)."""
    protocol = self.protocol
    return protocol
@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """The reply stored after the HELLO handshake.

    Keys and values are either bytes or str.
    """
    metadata = self._handshake_metadata
    return metadata


@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    self._handshake_metadata = value
def set_re_auth_token(self, token: TokenInterface):
    """Store ``token`` so the next ``re_auth()`` call authenticates with it."""
    self._re_auth_token = token
def re_auth(self):
    """Re-authenticate with the pending token, if one was set.

    Sends ``AUTH <token oid> <token value>``, reads the reply, and then
    clears the pending token. Does nothing when no token is pending.
    """
    token = self._re_auth_token
    if token is None:
        return
    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None
1225 def _get_socket(self) -> Optional[socket.socket]:
1226 return self._sock
@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied to socket operations after connecting."""
    timeout = self._socket_timeout
    return timeout


@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    self._socket_timeout = value
@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied to the socket while connecting."""
    timeout = self._socket_connect_timeout
    return timeout


@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    self._socket_connect_timeout = value
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """Store the TCP settings and initialize the base connection.

        Args:
            host: Server hostname or IP address.
            port: Server port; converted with ``int()``.
            socket_keepalive: If true, enable ``SO_KEEPALIVE`` on the socket.
            socket_keepalive_options: Mapping of ``IPPROTO_TCP`` option
                numbers to values. Applied only when ``socket_keepalive`` is
                set; defaults to an empty mapping.
            socket_type: Passed as the address-family argument of
                ``socket.getaddrinfo`` (0 means any family).
            **kwargs: Forwarded to ``AbstractConnection.__init__``.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return the ``(name, value)`` pairs used in the connection repr."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Like ``socket.create_connection``, this tries every address that
        ``getaddrinfo`` returns, which covers both IPv4 and IPv6. Unlike it,
        socket options are set before ``connect()`` is called.

        Returns:
            The connected socket, with ``socket_timeout`` applied.

        Raises:
            OSError: the error from the last failed address, or a generic
                error if ``getaddrinfo`` returned no addresses. Any other
                exception is re-raised after the socket is closed.
        """
        err = None

        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, _canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the failure and try the next address.
                err = exc
                if sock is not None:
                    self._close_failed_socket(sock)
            except BaseException:
                # Any other error (e.g. a TypeError from an invalid keepalive
                # option, or KeyboardInterrupt) aborts the whole attempt.
                # Close the socket first so it is not leaked.
                if sock is not None:
                    self._close_failed_socket(sock)
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    @staticmethod
    def _close_failed_socket(sock):
        """Shut down and close a socket whose connection attempt failed."""
        try:
            sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
        except OSError:
            pass
        sock.close()

    def _host_error(self):
        """Return ``host:port`` for use in error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection wrapper that adds client-side caching to ``conn``.

    Cacheable commands are looked up in ``cache``. Only a cache miss is
    actually sent over the wrapped connection, and its reply is stored.
    ``CLIENT TRACKING ON`` is enabled on connect, and invalidation push
    messages remove the affected keys (or flush the cache). Most other
    methods delegate straight to the wrapped connection.
    """

    # Placeholder stored while a cacheable command's reply is in flight.
    DUMMY_CACHE_VALUE = b"foo"
    # Oldest server version accepted by connect().
    MIN_ALLOWED_VERSION = "7.4.0"
    # connect() only accepts servers that report this name in the handshake.
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """Wrap ``conn`` with the shared ``cache``.

        Args:
            conn: The connection used for all network I/O.
            cache: The cache shared with other connections.
            pool_lock: The pool's lock; held while draining another
                connection's pending invalidations.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    @staticmethod
    def _handshake_value(metadata, key):
        """Look up ``key`` in handshake metadata, trying bytes then str keys."""
        value = metadata.get(key.encode(), None)
        if value is None:
            value = metadata.get(key, None)
        return value

    def connect(self):
        """Connect and verify that the server supports client-side caching.

        Raises:
            ConnectionError: if the server name or version is missing from
                the handshake metadata, or if the server is not Redis
                ``MIN_ALLOWED_VERSION`` or later.
        """
        self._conn.connect()

        # Metadata keys are bytes or str depending on response decoding.
        metadata = self._conn.handshake_metadata or {}
        server_name = self._handshake_value(metadata, "server")
        server_ver = self._handshake_value(metadata, "version")
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        # A compare_versions() result of 1 is treated as "older than
        # MIN_ALLOWED_VERSION" (see the error message below).
        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Forward check_health so callers that disable the health check
        # (e.g. to avoid recursion) get the same behavior as on a plain
        # connection.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or skip the network round-trip if it is cached.

        Non-cacheable commands are sent as-is. Cacheable commands need a
        ``keys`` keyword argument so that a cache key can be built. If an
        entry already exists, nothing is sent and ``read_response()`` serves
        the cached value. Otherwise an IN_PROGRESS placeholder is stored and
        the command is sent.

        Raises:
            ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                # We have to trigger invalidation processing in case if
                # it was cached by another connection to avoid
                # queueing invalidations in stale connections.
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)
                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def _discard_current_cache_entry(self):
        """Delete the pending command's cache entry (if any) and forget its key."""
        with self._cache_lock:
            if self._current_command_cache_key is not None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                self._current_command_cache_key = None

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply to the last command, from the cache when possible.

        A cached VALID entry is returned as a deep copy, so callers cannot
        mutate the cached value. Otherwise the reply is read from the wrapped
        connection and, for cacheable commands, stored in the cache. The
        placeholder entry is dropped if the read fails or the reply is None.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                entry = self._cache.get(self._current_command_cache_key)
                if (
                    entry is not None
                    and entry.status != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(entry.cache_value)
                    self._current_command_cache_key = None
                    return res

        try:
            response = self._conn.read_response(
                disable_decoding=disable_decoding,
                disconnect_on_error=disconnect_on_error,
                push_request=push_request,
            )
        except BaseException:
            # The reply will never be stored, so drop the IN_PROGRESS
            # placeholder. Otherwise a later identical command would find it,
            # skip sending, and then wait for a reply that never arrives.
            self._discard_current_cache_entry()
            raise

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._discard_current_cache_entry()
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that the connection instance supports maintenance notifications.

        Once validated, the instance is returned typed as
        MaintNotificationsAbstractConnection.

        Raises:
            NotImplementedError: if ``connection`` does not support
                maintenance notifications.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable CLIENT TRACKING and hook invalidations."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Read and dispatch every push message already waiting on the socket."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Drop invalidated keys from the cache, or flush it all when data[1] is None."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. Forced to False when ssl_cert_reqs is "none".
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError: if Python lacks SSL support or ssl_cert_reqs is an unknown string.
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Open a TCP connection and wrap it with SSL.

        The plain socket is closed if wrapping or validation fails for any
        reason, including exceptions that are not OSError/RedisError.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if OCSP validation is misconfigured.
            ConnectionError: if pure OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)
        try:
            self._validate_ocsp(sslsock)
        except BaseException:
            # wrap_socket() takes over the plain socket's file descriptor, so
            # the sock.close() in _connect() would not release it. Close the
            # SSL socket here to avoid leaking the connection.
            sslsock.close()
            raise
        return sslsock

    def _validate_ocsp(self, sslsock):
        """Run whichever OCSP check is configured (stapled or pure)."""
        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            self._validate_ocsp_staple()
            return

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if not o.is_valid():
                raise ConnectionError("ocsp validation error")

    def _validate_ocsp_staple(self):
        """Verify the server's stapled OCSP response over a separate connection."""
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # A separate socket is needed to request the staple. Always close it;
        # con.shutdown() only ends the SSL session, not the socket.
        staple_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            staple_sock.close()
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """Initialize the base connection, then record the socket path.

        Args:
            path: Filesystem path of the Unix domain socket.
            socket_timeout: Timeout for socket operations after connecting.
                It is assigned after the base initializer runs, so it
                overrides anything that initializer set.
            **kwargs: Forwarded to ``AbstractConnection.__init__``.
        """
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the ``(name, value)`` pairs used in the connection repr."""
        extra = [("client_name", self.client_name)] if self.client_name else []
        return [("path", self.path), ("db", self.db)] + extra

    def _connect(self):
        """Open and return a connected Unix domain socket to ``self.path``."""
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Close explicitly so a failed attempt does not leave an
            # unclosed socket behind (ResourceWarning).
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        uds.settimeout(self.socket_timeout)
        return uds

    def _host_error(self):
        """Return the socket path for use in error messages."""
        return self.path
# Upper-case spellings that to_bool() maps to False (input is upper-cased).
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Convert a URL query-string value to a boolean.

    Returns None for None or the empty string. Returns False for any string
    whose upper-cased form is in FALSE_STRINGS. Returns ``bool(value)``
    otherwise.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)
def parse_ssl_verify_flags(value):
    """Parse ``ssl.VerifyFlags`` member names from a URL query value.

    Flags are passed as a string representation of a list, with or without
    brackets. For example, ``[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]``
    or ``VERIFY_X509_STRICT,VERIFY_X509_PARTIAL_CHAIN``. Empty entries, such
    as a trailing comma or ``[]``, are ignored.

    Returns:
        A list of ``VerifyFlags`` members, in input order.

    Raises:
        ValueError: if a name is not a ``VerifyFlags`` member. This includes
            every name when SSL support (and so ``VerifyFlags``) is missing.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # Look names up among the enum members only. A plain hasattr() check
    # would also accept inherited attributes such as "from_bytes".
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
# Converters applied by parse_url() to query-string arguments, keyed by
# argument name. Each receives the (unquoted) string value. Arguments not
# listed here are passed through unchanged as strings.
URL_QUERY_ARGUMENT_PARSERS = {
    # connection selection / timeouts
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    # retry behaviour
    "retry_on_timeout": to_bool,
    # NOTE(review): ``list`` applied to a string splits it into single
    # characters (e.g. "ab" -> ["a", "b"]); confirm this is intended.
    "retry_on_error": list,
    # pool sizing / health
    "max_connections": int,
    "health_check_interval": int,
    # TLS
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
def parse_url(url):
    """Parse a Redis connection URL into connection keyword arguments.

    Supported schemes:

    * ``redis://`` for TCP;
    * ``rediss://`` for TCP with SSL (sets ``connection_class``);
    * ``unix://`` for a Unix domain socket (sets ``connection_class``).

    Query-string arguments named in ``URL_QUERY_ARGUMENT_PARSERS`` are
    converted with the matching parser; other arguments are kept as strings.
    For TCP URLs, a numeric path such as ``/2`` selects the database unless
    ``db`` was given in the query string.

    Args:
        url: The connection URL string.

    Returns:
        A dict of keyword arguments.

    Raises:
        ValueError: if the scheme is unsupported, or a query argument cannot
            be converted by its parser.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        # parse_qs omits blank values by default, but stay defensive.
        if not values:
            continue
        # NOTE(review): parse_qs already percent-decodes values, so this
        # unquote() decodes a second time; kept for compatibility.
        value = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = value
            continue
        try:
            kwargs[name] = parser(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid value for '{name}' in connection URL."
            ) from err

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
# TypeVar bound (by forward-reference name) to ConnectionPool. Presumably used
# to annotate pool methods that return an instance of the calling subclass;
# usage is not visible in this section.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface that connection pool implementations provide."""

    @abstractmethod
    def get_protocol(self):
        """Return the protocol setting used by the pool."""

    @abstractmethod
    def reset(self):
        """Reset the pool."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Return a connection from the pool.

        Passing any arguments is deprecated since 5.3.0; call
        ``get_connection()`` without arguments instead.
        """

    @abstractmethod
    def get_encoder(self):
        """Return the encoder used by the pool."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Give ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect the pool's connections.

        ``inuse_connections`` presumably controls whether connections that
        are currently checked out are included (TODO confirm in implementations).
        """

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the ``Retry`` policy used by the pool."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Handle a new authentication ``token`` for the pool's connections."""
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.

    Concrete pools must implement the ``connection_kwargs`` property (with a
    setter), ``_get_pool_lock()``, ``_get_free_connections()`` and
    ``_get_in_use_connections()``. The mixin itself does not derive from
    ``ABC``.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        Set up the maintenance notifications pool handler.

        :param maint_notifications_config: Explicit configuration. When it is
            ``None`` and the ``protocol`` kwarg is RESP3, a default
            ``MaintNotificationsConfig()`` is created.
        :param kwargs: The pool's connection kwargs; only ``protocol`` is read.
        :raises RedisError: If an enabled configuration is requested while the
            protocol is not RESP3.
        """
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    def maint_notifications_enabled(self) -> bool:
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are
            not enabled.
        """
        handler = self._maint_notifications_pool_handler
        if not handler:
            return False
        config = handler.config
        # bool() so callers always get True/False (the previous expression
        # returned None when no handler was set).
        return bool(config and config.enabled)

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.

        This method should be called only if the pool was created without
        enabling the maintenance notifications and, at a later point in time,
        maintenance notifications are requested to be enabled.

        :raises ValueError: If notifications are currently enabled and the new
            configuration would disable them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.

        No-op unless maintenance notifications are enabled. The original
        host/timeout values are captured only once, so later calls do not
        overwrite them.
        """
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    @staticmethod
    def _apply_maint_notifications_pool_handler(
        conn: "MaintNotificationsAbstractConnection",
        maint_notifications_pool_handler: MaintNotificationsPoolHandler,
    ) -> None:
        """Attach the pool handler and its current config to ``conn``."""
        conn.set_maint_notifications_pool_handler_for_connection(
            maint_notifications_pool_handler
        )
        conn.maint_notifications_config = maint_notifications_pool_handler.config

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Update the maintenance notifications config for all connections in the pool."""
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                self._apply_maint_notifications_pool_handler(
                    conn, maint_notifications_pool_handler
                )
                # Idle connections are simply disconnected.
                conn.disconnect()
            for conn in self._get_in_use_connections():
                self._apply_maint_notifications_pool_handler(
                    conn, maint_notifications_pool_handler
                )
                # In-use connections are only flagged via mark_for_reconnect()
                # rather than being disconnected while checked out.
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy ``matching_address`` / ``matching_notification_hash`` acts as a
        wildcard: every connection matches.

        :param matching_pattern: ``"connected_address"`` compares the peer
            address (``conn.getpeername()``), ``"configured_address"`` compares
            ``conn.host``, ``"notification_hash"`` compares
            ``conn.maintenance_notification_hash``.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Temporary settings (host address / relaxed timeout) are applied first,
        then any requested resets, and finally the connection's current socket
        timeout is updated with ``relaxed_timeout`` (which may be ``None``).
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        # NOTE(review): called unconditionally, also with relaxed_timeout=None;
        # presumably the connection treats None as "use the original timeout".
        conn.update_current_socket_timeout(relaxed_timeout)

    def _iter_connections_to_update(self, include_free_connections: bool):
        """
        Yield the pool's in-use connections, then (optionally) its free ones.

        The free connections are fetched lazily, only after the in-use ones
        have been consumed. Callers iterate while holding the pool lock.
        """
        yield from self._get_in_use_connections()
        if include_free_connections:
            yield from self._get_free_connections()

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._iter_connections_to_update(include_free_connections):
                if not self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    continue
                self.update_connection_settings(
                    conn,
                    state=state,
                    maintenance_notification_hash=maintenance_notification_hash,
                    host_address=host_address,
                    relaxed_timeout=relaxed_timeout,
                    update_notification_hash=update_notification_hash,
                    reset_host_address=reset_host_address,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every in-use connection is marked.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every free connection is disconnected.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win. The one exception is an explicitly passed
        ``connection_class`` keyword argument, which takes precedence over
        the class implied by the URL scheme.
        """
        url_options = parse_url(url)

        # An explicit connection_class beats the one derived from the scheme
        # (e.g. SSLConnection for rediss://).
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: Class used to create new connections.
        :param max_connections: Upper bound on created connections; a falsy
            value means 2**31.
        :param cache_factory: Used to build the client-side cache when
            ``cache_config`` is given but no ``cache`` instance is.
        :param maint_notifications_config: See the class docstring.
        :param connection_kwargs: Passed to ``connection_class``. The ``cache``
            and ``cache_config`` entries are consumed by the pool itself.
        :raises ValueError: For an invalid ``max_connections`` or a ``cache``
            that does not implement ``CacheInterface``.
        :raises RedisError: If client caching is requested without RESP3.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

        # The cache is owned by the pool; connections must not receive it.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct new connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Forget all tracked connections and record the current process id."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """
        Get a connection from the pool.

        Reuses an idle connection when available, otherwise creates one. The
        connection is connected and checked for unread data before being
        returned; on any failure it is released back to the pool and the
        exception is re-raised.
        """
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # (the pending-data check is skipped when client caching or
            # maintenance notifications are enabled.)
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        """Return an encoder based on encoding settings."""
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new connection.

        :raises MaxConnectionsError: If ``max_connections`` connections have
            already been created.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        kwargs = dict(self.connection_kwargs)

        if self.cache is not None:
            return CacheProxyConnection(
                self.connection_class(**kwargs), self.cache, self._lock
            )
        return self.connection_class(**kwargs)

    def release(self, connection: "Connection") -> None:
        """Releases the connection back to the pool."""
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created under this pool's current pid."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """
        Use ``retry`` for all future connections and for every connection the
        pool currently tracks (idle and in use).
        """
        # Hold the pool lock so the containers cannot change while iterated,
        # and go through the _get_*_connections() accessors so subclasses that
        # track connections differently (e.g. BlockingConnectionPool, which has
        # no _available_connections/_in_use_connections) work too.
        with self._lock:
            self.connection_kwargs.update({"retry": retry})
            for conn in chain(
                self._get_free_connections(), self._get_in_use_connections()
            ):
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate the pool's connections with ``token``.

        Idle connections are sent ``AUTH`` immediately; in-use connections are
        handed the token through ``set_re_auth_token`` instead of having
        ``AUTH`` sent on them while checked out.
        """
        with self._lock:
            # The accessors (rather than the raw containers) keep this working
            # for subclasses such as BlockingConnectionPool.
            # NOTE(review): in BlockingConnectionPool this lock does not stop
            # other threads from taking an idle connection mid-AUTH outside
            # maintenance mode.
            for conn in self._get_free_connections():
                # The lambdas are invoked within this iteration, so capturing
                # the loop variable ``conn`` is safe.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._get_in_use_connections():
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        with self._lock:
            return self._in_use_connections

    def _mock(self, error: RedisError):
        """
        No-op failure callback passed to ``Retry.call_with_retry``.

        This must be a plain (synchronous) function: ``call_with_retry`` is
        invoked synchronously here, so a coroutine returned by an
        ``async def`` callback would never be awaited (Python emits a
        ``RuntimeWarning`` for that).

        :param error: The error that triggered the retry (ignored).
        """
        pass
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of its connections are in use, rather
    than raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)

    While the pool is in maintenance mode (see ``set_in_maintenance``), the
    pool-state methods serialize on the pool lock.
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        self.queue_class = queue_class
        self.timeout = timeout
        self._in_maintenance = False
        # No longer used by this class: lock ownership is tracked per call
        # (see _acquire_maintenance_lock), because a flag shared between
        # threads let one thread clear another's flag and leak the lock.
        # Kept only in case external code reads it.
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def _acquire_maintenance_lock(self) -> bool:
        """
        Acquire the pool lock if the pool is in maintenance mode.

        Returns:
            True if the lock was acquired; the caller must then pass this
            value to ``_release_maintenance_lock`` exactly once.
        """
        if self._in_maintenance:
            self._lock.acquire()
            return True
        return False

    def _release_maintenance_lock(self, acquired: bool) -> None:
        """Release the pool lock if ``_acquire_maintenance_lock`` took it."""
        if acquired:
            self._lock.release()

    def reset(self):
        """Refill the queue with ``None`` placeholders and forget all connections."""
        acquired = self._acquire_maintenance_lock()
        try:
            # Create and fill up a thread safe queue with ``None`` values.
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            self._release_maintenance_lock(acquired)

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        """Make a fresh connection and track it in ``self._connections``."""
        acquired = self._acquire_maintenance_lock()
        try:
            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            self._connections.append(connection)
            return connection
        finally:
            self._release_maintenance_lock(acquired)

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        :raises ConnectionError: If no connection becomes available within
            ``self.timeout`` seconds.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        # NOTE(review): in maintenance mode the pool lock is held for the whole
        # (possibly blocking) wait below, so release() from other threads,
        # which also takes the lock in maintenance mode, cannot return a
        # connection until this wait ends — confirm this is intended.
        acquired = self._acquire_maintenance_lock()
        try:
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. To wait
                # indefinitely instead, create the pool with ``timeout=None``.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                try:
                    connection = self.make_connection()
                except BaseException:
                    # Give the placeholder back; otherwise every failed
                    # construction would permanently shrink the pool.
                    try:
                        self.pool.put_nowait(None)
                    except Full:
                        pass
                    raise
        finally:
            self._release_maintenance_lock(acquired)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # NOTE(review): unlike ConnectionPool.get_connection, this check is
            # not skipped when client caching or maintenance notifications are
            # enabled — confirm intended.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        """Releases the connection back to the pool."""
        # Make sure we haven't changed process.
        self._checkpid()

        acquired = self._acquire_maintenance_lock()
        try:
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                try:
                    self.pool.put_nowait(None)
                except Full:
                    # e.g. the pool was reset() after a fork and is already
                    # full of placeholders.
                    pass
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            self._release_maintenance_lock(acquired)

    def disconnect(self, inuse_connections: bool = True):
        """Disconnects either all connections in the pool or just the free connections."""
        self._checkpid()
        acquired = self._acquire_maintenance_lock()
        try:
            if inuse_connections:
                connections = self._connections
            else:
                connections = self._get_free_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            self._release_maintenance_lock(acquired)

    def _queued_connections(self):
        """Return the real connections currently waiting in the queue."""
        # Snapshot the deque first: self._lock does not guard the queue, and
        # iterating it directly while another thread puts/gets can fail with
        # "deque mutated during iteration". (Copying a deque in one C-level
        # call is presumably atomic under CPython's GIL.)
        queued = tuple(self.pool.queue)
        # ``None`` entries are placeholders for not-yet-created connections.
        return {conn for conn in queued if conn is not None}

    def _get_free_connections(self):
        with self._lock:
            return self._queued_connections()

    def _get_in_use_connections(self):
        with self._lock:
            # free connections
            connections_in_queue = self._queued_connections()
            # in self._connections we keep all created connections
            # so the ones that are not in the queue are the in use ones
            return {
                conn for conn in self._connections if conn not in connections_in_queue
            }

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this Blocking ConnectionPool is in maintenance mode.

        This is used to prevent new connections from being created while we are in maintenance mode.
        The pool will be in maintenance mode only when we are processing a MOVING notification.
        """
        self._in_maintenance = in_maintenance