Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .event import AfterConnectionReleasedEvent, EventDispatcher
39from .exceptions import (
40 AuthenticationError,
41 AuthenticationWrongNumberOfArgsError,
42 ChildDeadlockedError,
43 ConnectionError,
44 DataError,
45 MaxConnectionsError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from .maint_notifications import (
51 MaintenanceState,
52 MaintNotificationsConfig,
53 MaintNotificationsConnectionHandler,
54 MaintNotificationsPoolHandler,
55)
56from .retry import Retry
57from .utils import (
58 CRYPTOGRAPHY_AVAILABLE,
59 HIREDIS_AVAILABLE,
60 SSL_AVAILABLE,
61 compare_versions,
62 deprecated_args,
63 ensure_string,
64 format_error_message,
65 get_lib_version,
66 str_if_bytes,
67)
# ssl support is optional in CPython builds; SSL_AVAILABLE (from .utils)
# reports whether it can be imported, and we degrade to None placeholders
# so attribute access fails loudly only when TLS is actually requested.
if SSL_AVAILABLE:
    import ssl
    from ssl import VerifyFlags
else:
    ssl = None
    VerifyFlags = None

# hiredis is an optional C accelerator used for parsing and command packing.
if HIREDIS_AVAILABLE:
    import hiredis

# RESP wire-protocol framing bytes used by PythonRespSerializer.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when `protocol` cannot be interpreted as an int
# (e.g. protocol=None triggers the TypeError fallback in __init__).
DEFAULT_RESP_VERSION = 2

# Unique marker object distinguishing "argument not supplied" from None.
SENTINEL = object()

# Prefer the hiredis parser when available; it handles both RESP2 and RESP3.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser
class HiredisRespSerializer:
    """Serialize commands into RESP wire bytes using the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        command = args
        first = command[0]
        # A command name supplied as one string (e.g. 'CONFIG GET') must be
        # split into separate arguments before packing.
        if isinstance(first, str):
            command = tuple(first.encode().split()) + command[1:]
        elif b" " in first:
            command = tuple(first.split()) + command[1:]

        packed = []
        try:
            packed.append(hiredis.pack_command(command))
        except TypeError:
            # Re-raise as the library's DataError, preserving the traceback.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return packed
class PythonRespSerializer:
    """Pure-Python fallback serializer producing RESP wire bytes."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Byte threshold above which values are emitted as separate chunks.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        cutoff = self._buffer_cutoff
        # Accumulate small pieces into one pending buffer; flush it whenever a
        # large value (or memoryview) must be appended as its own chunk to
        # avoid large string mallocs.
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for value in map(self.encode, args):
            value_len = len(value)
            oversized = len(pending) > cutoff or value_len > cutoff
            if oversized or isinstance(value, memoryview):
                header = SYM_EMPTY.join(
                    (pending, SYM_DOLLAR, str(value_len).encode(), SYM_CRLF)
                )
                chunks.append(header)
                chunks.append(value)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join(
                    (
                        pending,
                        SYM_DOLLAR,
                        str(value_len).encode(),
                        SYM_CRLF,
                        value,
                        SYM_CRLF,
                    )
                )
        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """
    Interface describing the contract of a synchronous Redis connection.

    NOTE(review): methods are decorated with @abstractmethod but the class
    does not inherit from abc.ABC, so instantiation of incomplete subclasses
    is not blocked at runtime — the decorators are documentation-only here;
    confirm whether inheriting ABC is intended.
    """

    @abstractmethod
    def repr_pieces(self):
        # (name, value) pairs used by subclasses to build repr().
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        # Register a callable to run after the connection is (re)established.
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        # Remove a previously registered connect callback.
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        # Replace the response parser with a new instance of parser_class.
        pass

    @abstractmethod
    def get_protocol(self):
        # RESP protocol version in use.
        pass

    @abstractmethod
    def connect(self):
        # Establish the connection to the server.
        pass

    @abstractmethod
    def on_connect(self):
        # Post-connect initialization (auth/handshake/etc.).
        pass

    @abstractmethod
    def disconnect(self, *args):
        # Tear down the underlying transport.
        pass

    @abstractmethod
    def check_health(self):
        # Verify the connection is still usable.
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        # Send bytes that are already RESP-encoded.
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        # Encode and send a command.
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        # Whether a reply can be read within `timeout`.
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        # Read and parse a single server reply.
        pass

    @abstractmethod
    def pack_command(self, *args):
        # RESP-encode one command into a list of byte chunks.
        pass

    @abstractmethod
    def pack_commands(self, commands):
        # RESP-encode a batch of commands (pipeline support).
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        # Metadata returned by the server's HELLO handshake.
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        # Store a token used for later re-authentication.
        pass

    @abstractmethod
    def re_auth(self):
        # Re-authenticate using the stored token.
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        # Goes through the maintenance_state property setter below.
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        # Concrete connections return their active parser instance.
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        # Concrete connections return the underlying socket, or None.
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            # Feature disabled: keep the handler slots present (but empty) so
            # other methods can test them without hasattr checks.
            # NOTE(review): on this path the orig_* attributes below are never
            # set, so reset_tmp_settings() would raise AttributeError — confirm
            # callers only reset when notifications are enabled.
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so that temporary overrides
        # (set_tmp_settings) can be reverted via reset_tmp_settings.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler already exists: only refresh its config.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        # Performs the CLIENT MAINT_NOTIFICATIONS handshake with the server.
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                # enabled=True (strict mode) or a non-protocol error: propagate.
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """
        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        # A relaxed_timeout of -1 means "restore the configured socket timeout"
        # rather than a literal timeout value.
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        # Keep the parser's read timeout in sync with the socket timeout.
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        # -1 means "leave the timeouts unchanged" (see update_current_socket_timeout).
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        # Restore the values captured in _configure_maintenance_notifications.
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
656 "Manages communication to and from a Redis server"
658 def __init__(
659 self,
660 db: int = 0,
661 password: Optional[str] = None,
662 socket_timeout: Optional[float] = None,
663 socket_connect_timeout: Optional[float] = None,
664 retry_on_timeout: bool = False,
665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
666 encoding: str = "utf-8",
667 encoding_errors: str = "strict",
668 decode_responses: bool = False,
669 parser_class=DefaultParser,
670 socket_read_size: int = 65536,
671 health_check_interval: int = 0,
672 client_name: Optional[str] = None,
673 lib_name: Optional[str] = "redis-py",
674 lib_version: Optional[str] = get_lib_version(),
675 username: Optional[str] = None,
676 retry: Union[Any, None] = None,
677 redis_connect_func: Optional[Callable[[], None]] = None,
678 credential_provider: Optional[CredentialProvider] = None,
679 protocol: Optional[int] = 2,
680 command_packer: Optional[Callable[[], None]] = None,
681 event_dispatcher: Optional[EventDispatcher] = None,
682 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
683 maint_notifications_pool_handler: Optional[
684 MaintNotificationsPoolHandler
685 ] = None,
686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
687 maintenance_notification_hash: Optional[int] = None,
688 orig_host_address: Optional[str] = None,
689 orig_socket_timeout: Optional[float] = None,
690 orig_socket_connect_timeout: Optional[float] = None,
691 ):
692 """
693 Initialize a new Connection.
694 To specify a retry policy for specific errors, first set
695 `retry_on_error` to a list of the error/s to retry on, then set
696 `retry` to a valid `Retry` object.
697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
698 """
699 if (username or password) and credential_provider is not None:
700 raise DataError(
701 "'username' and 'password' cannot be passed along with 'credential_"
702 "provider'. Please provide only one of the following arguments: \n"
703 "1. 'password' and (optional) 'username'\n"
704 "2. 'credential_provider'"
705 )
706 if event_dispatcher is None:
707 self._event_dispatcher = EventDispatcher()
708 else:
709 self._event_dispatcher = event_dispatcher
710 self.pid = os.getpid()
711 self.db = db
712 self.client_name = client_name
713 self.lib_name = lib_name
714 self.lib_version = lib_version
715 self.credential_provider = credential_provider
716 self.password = password
717 self.username = username
718 self._socket_timeout = socket_timeout
719 if socket_connect_timeout is None:
720 socket_connect_timeout = socket_timeout
721 self._socket_connect_timeout = socket_connect_timeout
722 self.retry_on_timeout = retry_on_timeout
723 if retry_on_error is SENTINEL:
724 retry_on_errors_list = []
725 else:
726 retry_on_errors_list = list(retry_on_error)
727 if retry_on_timeout:
728 # Add TimeoutError to the errors list to retry on
729 retry_on_errors_list.append(TimeoutError)
730 self.retry_on_error = retry_on_errors_list
731 if retry or self.retry_on_error:
732 if retry is None:
733 self.retry = Retry(NoBackoff(), 1)
734 else:
735 # deep-copy the Retry object as it is mutable
736 self.retry = copy.deepcopy(retry)
737 if self.retry_on_error:
738 # Update the retry's supported errors with the specified errors
739 self.retry.update_supported_errors(self.retry_on_error)
740 else:
741 self.retry = Retry(NoBackoff(), 0)
742 self.health_check_interval = health_check_interval
743 self.next_health_check = 0
744 self.redis_connect_func = redis_connect_func
745 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
746 self.handshake_metadata = None
747 self._sock = None
748 self._socket_read_size = socket_read_size
749 self._connect_callbacks = []
750 self._buffer_cutoff = 6000
751 self._re_auth_token: Optional[TokenInterface] = None
752 try:
753 p = int(protocol)
754 except TypeError:
755 p = DEFAULT_RESP_VERSION
756 except ValueError:
757 raise ConnectionError("protocol must be an integer")
758 finally:
759 if p < 2 or p > 3:
760 raise ConnectionError("protocol must be either 2 or 3")
761 # p = DEFAULT_RESP_VERSION
762 self.protocol = p
763 if self.protocol == 3 and parser_class == _RESP2Parser:
764 # If the protocol is 3 but the parser is RESP2, change it to RESP3
765 # This is needed because the parser might be set before the protocol
766 # or might be provided as a kwarg to the constructor
767 # We need to react on discrepancy only for RESP2 and RESP3
768 # as hiredis supports both
769 parser_class = _RESP3Parser
770 self.set_parser(parser_class)
772 self._command_packer = self._construct_command_packer(command_packer)
773 self._should_reconnect = False
775 # Set up maintenance notifications
776 MaintNotificationsAbstractConnection.__init__(
777 self,
778 maint_notifications_config,
779 maint_notifications_pool_handler,
780 maintenance_state,
781 maintenance_notification_hash,
782 orig_host_address,
783 orig_socket_timeout,
784 orig_socket_connect_timeout,
785 self._parser,
786 )
788 def __repr__(self):
789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
    @abstractmethod
    def repr_pieces(self):
        # Subclasses return (name, value) pairs consumed by __repr__.
        pass
    def __del__(self):
        # Best-effort cleanup: destructor errors must never propagate during
        # garbage collection or interpreter shutdown.
        try:
            self.disconnect()
        except Exception:
            pass
802 def _construct_command_packer(self, packer):
803 if packer is not None:
804 return packer
805 elif HIREDIS_AVAILABLE:
806 return HiredisRespSerializer()
807 else:
808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
810 def register_connect_callback(self, callback):
811 """
812 Register a callback to be called when the connection is established either
813 initially or reconnected. This allows listeners to issue commands that
814 are ephemeral to the connection, for example pub/sub subscription or
815 key tracking. The callback must be a _method_ and will be kept as
816 a weak reference.
817 """
818 wm = weakref.WeakMethod(callback)
819 if wm not in self._connect_callbacks:
820 self._connect_callbacks.append(wm)
822 def deregister_connect_callback(self, callback):
823 """
824 De-register a previously registered callback. It will no-longer receive
825 notifications on connection events. Calling this is not required when the
826 listener goes away, since the callbacks are kept as weak methods.
827 """
828 try:
829 self._connect_callbacks.remove(weakref.WeakMethod(callback))
830 except ValueError:
831 pass
    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        # Replacing the parser discards any existing parser state;
        # on_connect_check_health re-initializes it via _parser.on_connect().
        self._parser = parser_class(socket_read_size=self._socket_read_size)
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        # Accessor required by MaintNotificationsAbstractConnection.
        return self._parser
    def connect(self):
        "Connects to the Redis server if not already connected"
        # Delegate with the post-connect health check enabled.
        self.connect_check_health(check_health=True)
    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Establish the socket connection (no-op if already connected) and run
        post-connect initialization.

        Args:
            check_health: forwarded to on_connect_check_health(); controls
                whether commands sent during initialization may trigger a
                health check.
            retry_socket_connect: when True, the raw socket connect is wrapped
                in this connection's Retry policy.
        """
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)
    @abstractmethod
    def _connect(self):
        # Open and return the raw socket; implemented per transport subclass.
        pass
    @abstractmethod
    def _host_error(self):
        # Identifier of the connection target used when formatting errors
        # (see _error_message).
        pass
    def _error_message(self, exception):
        # Delegate message formatting to utils.format_error_message.
        return format_error_message(self._host_error(), exception)
    def on_connect(self):
        # Full post-connect initialization, health checks allowed.
        self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        # Keep a reference to the pre-handshake parser: if the parser is
        # swapped to RESP3 below, its EXCEPTION_CLASSES are carried over.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # Password-only credentials authenticate as the default user.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The "proto" key may be bytes or str depending on decode_responses.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # Older servers may not support CLIENT SETINFO; best effort only.
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
    def disconnect(self, *args):
        """Disconnect from the Redis server and release the socket.

        Safe to call on an already-disconnected connection; extra
        positional arguments are accepted (and ignored) for interface
        compatibility.
        """
        self._parser.on_disconnect()

        # Detach the socket reference first so the connection reads as
        # closed even if the teardown below raises.
        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self.reset_should_reconnect()
        if conn_sock is None:
            return

        # Only shut down the socket in the process that created it; a
        # forked child must not tear down the parent's live connection.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        # Closing is best-effort; the peer may already have gone away.
        try:
            conn_sock.close()
        except OSError:
            pass
    def mark_for_reconnect(self):
        """Flag this connection to be re-established before its next use."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if the connection has been marked for reconnect."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag."""
        self._should_reconnect = False
1048 def _send_ping(self):
1049 """Send PING, expect PONG in return"""
1050 self.send_command("PING", check_health=False)
1051 if str_if_bytes(self.read_response()) != "PONG":
1052 raise ConnectionError("Bad response from PING health check")
    def _ping_failed(self, error):
        """Retry hook: drop the connection when a health-check PING fails."""
        self.disconnect()
1058 def check_health(self):
1059 """Check the health of the connection with a PING/PONG"""
1060 if self.health_check_interval and time.monotonic() > self.next_health_check:
1061 self.retry.call_with_retry(self._send_ping, self._ping_failed)
    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server.

        Args:
            command: A str/bytes buffer, or an iterable of buffers, already
                encoded in the Redis protocol.
            check_health: When True, run the periodic PING health check
                before writing.

        Raises:
            TimeoutError: if the socket write times out.
            ConnectionError: for any other socket-level write failure.
            In every failure case the connection is disconnected first.
        """
        # Lazily (re)connect; the connect itself must not trigger a
        # health check, which would recurse back into this method.
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # OSError may carry (errno, message) or just (message,)
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise
1094 def send_command(self, *args, **kwargs):
1095 """Pack and send a command to the Redis server"""
1096 self.send_packed_command(
1097 self._command_packer.pack(*args),
1098 check_health=kwargs.get("check_health", True),
1099 )
    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read.

        Connects lazily if no socket is established yet. Delegates to
        the parser, which may also have already-buffered data.
        """
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)

        except OSError as e:
            # A failed poll means the connection is unusable; drop it.
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        Args:
            disable_decoding: When True, skip decoding bytes to str.
            disconnect_on_error: When True (default), tear down the
                connection on any read failure before re-raising.
            push_request: Forwarded to the RESP3 parser to read push
                messages; ignored for RESP2.

        Raises:
            TimeoutError / ConnectionError on socket failures, or the
            server-side ResponseError embedded in the reply.
        """

        host_error = self._host_error()

        try:
            # Only the RESP3 parser understands push messages.
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        # A successful read counts as proof of liveness; push the next
        # health-check deadline forward.
        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        # Server errors are returned by the parser as exception objects;
        # raise them here so callers see them as exceptions.
        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response
    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol."""
        return self._command_packer.pack(*args)
1164 def pack_commands(self, commands):
1165 """Pack multiple commands into the Redis protocol"""
1166 output = []
1167 pieces = []
1168 buffer_length = 0
1169 buffer_cutoff = self._buffer_cutoff
1171 for cmd in commands:
1172 for chunk in self._command_packer.pack(*cmd):
1173 chunklen = len(chunk)
1174 if (
1175 buffer_length > buffer_cutoff
1176 or chunklen > buffer_cutoff
1177 or isinstance(chunk, memoryview)
1178 ):
1179 if pieces:
1180 output.append(SYM_EMPTY.join(pieces))
1181 buffer_length = 0
1182 pieces = []
1184 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1185 output.append(chunk)
1186 else:
1187 pieces.append(chunk)
1188 buffer_length += chunklen
1190 if pieces:
1191 output.append(SYM_EMPTY.join(pieces))
1192 return output
    def get_protocol(self) -> Union[int, str]:
        """Return the configured RESP protocol version."""
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata captured from the HELLO handshake response."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value
1205 def set_re_auth_token(self, token: TokenInterface):
1206 self._re_auth_token = token
1208 def re_auth(self):
1209 if self._re_auth_token is not None:
1210 self.send_command(
1211 "AUTH",
1212 self._re_auth_token.try_get("oid"),
1213 self._re_auth_token.get_value(),
1214 )
1215 self.read_response()
1216 self._re_auth_token = None
    def _get_socket(self) -> Optional[socket.socket]:
        # Internal accessor for the raw socket; None when disconnected.
        return self._sock

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Read/write timeout (seconds) applied once connected."""
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Timeout (seconds) used while establishing the connection."""
        return self._socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._socket_connect_timeout = value
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        Args:
            host: Server hostname or IP.
            port: Server port; coerced to int.
            socket_keepalive: Enable SO_KEEPALIVE on the socket.
            socket_keepalive_options: Mapping of TCP keepalive socket
                options to values, applied when keepalive is enabled.
            socket_type: Address family passed to socket.getaddrinfo
                (0 lets the resolver pick IPv4/IPv6).
            **kwargs: Forwarded to AbstractConnection.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs used to build this connection's repr."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None

        # Try each resolved address in turn; remember the last failure
        # so it can be re-raised if none succeed.
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Return a 'host:port' string for use in error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        """Hostname/IP this connection targets."""
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Decorator around a real connection that adds client-side caching.

    Responses to cacheable read-only commands are stored in the shared
    cache and served locally until the server invalidates them via
    RESP3 push messages (CLIENT TRACKING). Everything else is delegated
    to the wrapped connection.
    """

    # Placeholder stored while a cacheable command is in flight, so
    # concurrent connections don't race to populate the same key.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires a Redis server of at least this version.
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        Args:
            conn: The underlying connection to wrap.
            cache: Shared cache used to store command responses.
            pool_lock: Lock of the owning pool, held while draining
                invalidations that arrived on other connections.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting a response, or None.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        # Mirror the wrapped connection's maintenance-notification state
        # when the wrapped connection supports it.
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and verify the server supports
        client-side caching (a Redis server, version >= 7.4).

        Raises:
            ConnectionError: if server name/version cannot be determined
                from the handshake, or the server is unsupported.
        """
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        # BUGFIX: the original checked "server_ver is None" twice, so a
        # missing server name slipped past this guard.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        # compare_versions(...) == 1 is treated here as "server older
        # than MIN_ALLOWED_VERSION" (see redis.utils.compare_versions).
        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        # A disconnect invalidates everything we cached through this
        # connection, since tracking state is lost with the socket.
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # BUGFIX: forward check_health instead of silently dropping it.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, short-circuiting via the cache when possible."""
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Creates cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the cached value when available, otherwise read from
        the wrapped connection and cache the result if appropriate."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                # Deep-copy so callers can't mutate the cached object.
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance
        notifications we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # BUGFIX: the delegated result was previously computed but not
        # returned, so error messages rendered as "None".
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        # Enable server-assisted client-side caching on every (re)connect.
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        # Drain any queued invalidation push messages before sending.
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """ # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """ # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # Accept either an ssl.VerifyMode or one of the string aliases.
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # ssl forbids hostname checking when certificates are not
        # verified, so force it off when verify_mode is CERT_NONE.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # The plain TCP socket is ours; close it if wrapping fails.
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        # Apply explicit verify-flag additions/removals on top of the
        # default context's flags.
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            # NOTE(review): this throwaway socket.socket() is never
            # explicitly closed after con.shutdown(); it relies on GC —
            # consider closing it here.
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    """Manages communication with a Redis server over a Unix domain socket."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs used to build this connection's repr."""
        parts = [("path", self.path), ("db", self.db)]
        if self.client_name:
            parts.append(("client_name", self.client_name))
        return parts

    def _connect(self):
        """Create and connect an AF_UNIX stream socket."""
        uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds_sock.settimeout(self.socket_connect_timeout)
        try:
            uds_sock.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                uds_sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds_sock.close()
            raise
        # Connected: switch from the connect timeout to the I/O timeout.
        uds_sock.settimeout(self.socket_timeout)
        return uds_sock

    def _host_error(self):
        """Return the socket path for use in error messages."""
        return self.path
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Interpret a URL querystring value as a boolean.

    ``None`` and the empty string map to ``None``; strings matching one
    of FALSE_STRINGS (case-insensitively) map to ``False``; anything
    else is coerced with ``bool()``.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        return value.upper() not in FALSE_STRINGS
    return bool(value)
1885def parse_ssl_verify_flags(value):
1886 # flags are passed in as a string representation of a list,
1887 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1888 verify_flags_str = value.replace("[", "").replace("]", "")
1890 verify_flags = []
1891 for flag in verify_flags_str.split(","):
1892 flag = flag.strip()
1893 if not hasattr(VerifyFlags, flag):
1894 raise ValueError(f"Invalid ssl verify flag: {flag}")
1895 verify_flags.append(getattr(VerifyFlags, flag))
1896 return verify_flags
# Converters applied by parse_url() to known querystring arguments;
# unknown arguments are passed through as raw strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
def parse_url(url):
    """Translate a redis://, rediss:// or unix:// URL into a dict of
    connection keyword arguments (host/port/db or path, credentials,
    converted querystring options, and the connection class to use).

    Raises:
        ValueError: for an unsupported scheme or an invalid option value.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    # Convert querystring options, applying known type converters.
    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        converter = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if converter is None:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = converter(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
    else:
        # redis:// or rediss://
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # A bare path component serves as the db number unless the
        # querystring already provided one.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
# Type variable bound to ConnectionPool so alternate constructors
# (e.g. from_url) can be typed as returning the invoking subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface implemented by connection pool variants."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by pooled connections."""
        pass

    @abstractmethod
    def reset(self):
        """Reset the pool's internal state."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Check a connection out of the pool."""
        pass

    @abstractmethod
    def get_encoder(self):
        """Return the encoder used by this pool's connections."""
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return a previously checked-out connection to the pool."""
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections (optionally in-use ones too)."""
        pass

    @abstractmethod
    def close(self):
        """Close the pool and release its resources."""
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy applied to pooled connections."""
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Callback invoked when credentials are refreshed."""
        pass
2020class MaintNotificationsAbstractConnectionPool:
2021 """
2022 Abstract class for handling maintenance notifications logic.
2023 This class is mixed into the ConnectionPool classes.
2025 This class is not intended to be used directly!
2027 All logic related to maintenance notifications and
2028 connection pool handling is encapsulated in this class.
2029 """
    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """Set up maintenance-notification handling for the pool.

        A default config is created automatically when the pool uses
        RESP3 and no explicit config is given; an explicitly enabled
        config on a non-RESP3 pool is an error.

        Args:
            maint_notifications_config: Optional notification settings.
            **kwargs: The pool's connection kwargs; only "protocol" is
                inspected here.
        """
        # Initialize maintenance notifications
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None
2057 @property
2058 @abstractmethod
2059 def connection_kwargs(self) -> Dict[str, Any]:
2060 pass
2062 @connection_kwargs.setter
2063 @abstractmethod
2064 def connection_kwargs(self, value: Dict[str, Any]):
2065 pass
2067 @abstractmethod
2068 def _get_pool_lock(self) -> threading.RLock:
2069 pass
2071 @abstractmethod
2072 def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
2073 pass
2075 @abstractmethod
2076 def _get_in_use_connections(
2077 self,
2078 ) -> Iterable["MaintNotificationsAbstractConnection"]:
2079 pass
2081 def maint_notifications_enabled(self):
2082 """
2083 Returns:
2084 True if the maintenance notifications are enabled, False otherwise.
2085 The maintenance notifications config is stored in the pool handler.
2086 If the pool handler is not set, the maintenance notifications are not enabled.
2087 """
2088 maint_notifications_config = (
2089 self._maint_notifications_pool_handler.config
2090 if self._maint_notifications_pool_handler
2091 else None
2092 )
2094 return maint_notifications_config and maint_notifications_config.enabled
2096 def update_maint_notifications_config(
2097 self, maint_notifications_config: MaintNotificationsConfig
2098 ):
2099 """
2100 Updates the maintenance notifications configuration.
2101 This method should be called only if the pool was created
2102 without enabling the maintenance notifications and
2103 in a later point in time maintenance notifications
2104 are requested to be enabled.
2105 """
2106 if (
2107 self.maint_notifications_enabled()
2108 and not maint_notifications_config.enabled
2109 ):
2110 raise ValueError(
2111 "Cannot disable maintenance notifications after enabling them"
2112 )
2113 # first update pool settings
2114 if not self._maint_notifications_pool_handler:
2115 self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
2116 self, maint_notifications_config
2117 )
2118 else:
2119 self._maint_notifications_pool_handler.config = maint_notifications_config
2121 # then update connection kwargs and existing connections
2122 self._update_connection_kwargs_for_maint_notifications(
2123 self._maint_notifications_pool_handler
2124 )
2125 self._update_maint_notifications_configs_for_connections(
2126 self._maint_notifications_pool_handler
2127 )
2129 def _update_connection_kwargs_for_maint_notifications(
2130 self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2131 ):
2132 """
2133 Update the connection kwargs for all future connections.
2134 """
2135 if not self.maint_notifications_enabled():
2136 return
2138 self.connection_kwargs.update(
2139 {
2140 "maint_notifications_pool_handler": maint_notifications_pool_handler,
2141 "maint_notifications_config": maint_notifications_pool_handler.config,
2142 }
2143 )
2145 # Store original connection parameters for maintenance notifications.
2146 if self.connection_kwargs.get("orig_host_address", None) is None:
2147 # If orig_host_address is None it means we haven't
2148 # configured the original values yet
2149 self.connection_kwargs.update(
2150 {
2151 "orig_host_address": self.connection_kwargs.get("host"),
2152 "orig_socket_timeout": self.connection_kwargs.get(
2153 "socket_timeout", None
2154 ),
2155 "orig_socket_connect_timeout": self.connection_kwargs.get(
2156 "socket_connect_timeout", None
2157 ),
2158 }
2159 )
2161 def _update_maint_notifications_configs_for_connections(
2162 self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2163 ):
2164 """Update the maintenance notifications config for all connections in the pool."""
2165 with self._get_pool_lock():
2166 for conn in self._get_free_connections():
2167 conn.set_maint_notifications_pool_handler_for_connection(
2168 maint_notifications_pool_handler
2169 )
2170 conn.maint_notifications_config = (
2171 maint_notifications_pool_handler.config
2172 )
2173 conn.disconnect()
2174 for conn in self._get_in_use_connections():
2175 conn.set_maint_notifications_pool_handler_for_connection(
2176 maint_notifications_pool_handler
2177 )
2178 conn.maint_notifications_config = (
2179 maint_notifications_pool_handler.config
2180 )
2181 conn.mark_for_reconnect()
2183 def _should_update_connection(
2184 self,
2185 conn: "MaintNotificationsAbstractConnection",
2186 matching_pattern: Literal[
2187 "connected_address", "configured_address", "notification_hash"
2188 ] = "connected_address",
2189 matching_address: Optional[str] = None,
2190 matching_notification_hash: Optional[int] = None,
2191 ) -> bool:
2192 """
2193 Check if the connection should be updated based on the matching criteria.
2194 """
2195 if matching_pattern == "connected_address":
2196 if matching_address and conn.getpeername() != matching_address:
2197 return False
2198 elif matching_pattern == "configured_address":
2199 if matching_address and conn.host != matching_address:
2200 return False
2201 elif matching_pattern == "notification_hash":
2202 if (
2203 matching_notification_hash
2204 and conn.maintenance_notification_hash != matching_notification_hash
2205 ):
2206 return False
2207 return True
2209 def update_connection_settings(
2210 self,
2211 conn: "MaintNotificationsAbstractConnection",
2212 state: Optional["MaintenanceState"] = None,
2213 maintenance_notification_hash: Optional[int] = None,
2214 host_address: Optional[str] = None,
2215 relaxed_timeout: Optional[float] = None,
2216 update_notification_hash: bool = False,
2217 reset_host_address: bool = False,
2218 reset_relaxed_timeout: bool = False,
2219 ):
2220 """
2221 Update the settings for a single connection.
2222 """
2223 if state:
2224 conn.maintenance_state = state
2226 if update_notification_hash:
2227 # update the notification hash only if requested
2228 conn.maintenance_notification_hash = maintenance_notification_hash
2230 if host_address is not None:
2231 conn.set_tmp_settings(tmp_host_address=host_address)
2233 if relaxed_timeout is not None:
2234 conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
2236 if reset_relaxed_timeout or reset_host_address:
2237 conn.reset_tmp_settings(
2238 reset_host_address=reset_host_address,
2239 reset_relaxed_timeout=reset_relaxed_timeout,
2240 )
2242 conn.update_current_socket_timeout(relaxed_timeout)
2244 def update_connections_settings(
2245 self,
2246 state: Optional["MaintenanceState"] = None,
2247 maintenance_notification_hash: Optional[int] = None,
2248 host_address: Optional[str] = None,
2249 relaxed_timeout: Optional[float] = None,
2250 matching_address: Optional[str] = None,
2251 matching_notification_hash: Optional[int] = None,
2252 matching_pattern: Literal[
2253 "connected_address", "configured_address", "notification_hash"
2254 ] = "connected_address",
2255 update_notification_hash: bool = False,
2256 reset_host_address: bool = False,
2257 reset_relaxed_timeout: bool = False,
2258 include_free_connections: bool = True,
2259 ):
2260 """
2261 Update the settings for all matching connections in the pool.
2263 This method does not create new connections.
2264 This method does not affect the connection kwargs.
2266 :param state: The maintenance state to set for the connection.
2267 :param maintenance_notification_hash: The hash of the maintenance notification
2268 to set for the connection.
2269 :param host_address: The host address to set for the connection.
2270 :param relaxed_timeout: The relaxed timeout to set for the connection.
2271 :param matching_address: The address to match for the connection.
2272 :param matching_notification_hash: The notification hash to match for the connection.
2273 :param matching_pattern: The pattern to match for the connection.
2274 :param update_notification_hash: Whether to update the notification hash for the connection.
2275 :param reset_host_address: Whether to reset the host address to the original address.
2276 :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2277 :param include_free_connections: Whether to include free/available connections.
2278 """
2279 with self._get_pool_lock():
2280 for conn in self._get_in_use_connections():
2281 if self._should_update_connection(
2282 conn,
2283 matching_pattern,
2284 matching_address,
2285 matching_notification_hash,
2286 ):
2287 self.update_connection_settings(
2288 conn,
2289 state=state,
2290 maintenance_notification_hash=maintenance_notification_hash,
2291 host_address=host_address,
2292 relaxed_timeout=relaxed_timeout,
2293 update_notification_hash=update_notification_hash,
2294 reset_host_address=reset_host_address,
2295 reset_relaxed_timeout=reset_relaxed_timeout,
2296 )
2298 if include_free_connections:
2299 for conn in self._get_free_connections():
2300 if self._should_update_connection(
2301 conn,
2302 matching_pattern,
2303 matching_address,
2304 matching_notification_hash,
2305 ):
2306 self.update_connection_settings(
2307 conn,
2308 state=state,
2309 maintenance_notification_hash=maintenance_notification_hash,
2310 host_address=host_address,
2311 relaxed_timeout=relaxed_timeout,
2312 update_notification_hash=update_notification_hash,
2313 reset_host_address=reset_host_address,
2314 reset_relaxed_timeout=reset_relaxed_timeout,
2315 )
2317 def update_connection_kwargs(
2318 self,
2319 **kwargs,
2320 ):
2321 """
2322 Update the connection kwargs for all future connections.
2324 This method updates the connection kwargs for all future connections created by the pool.
2325 Existing connections are not affected.
2326 """
2327 self.connection_kwargs.update(kwargs)
2329 def update_active_connections_for_reconnect(
2330 self,
2331 moving_address_src: Optional[str] = None,
2332 ):
2333 """
2334 Mark all active connections for reconnect.
2335 This is used when a cluster node is migrated to a different address.
2337 :param moving_address_src: The address of the node that is being moved.
2338 """
2339 with self._get_pool_lock():
2340 for conn in self._get_in_use_connections():
2341 if self._should_update_connection(
2342 conn, "connected_address", moving_address_src
2343 ):
2344 conn.mark_for_reconnect()
2346 def disconnect_free_connections(
2347 self,
2348 moving_address_src: Optional[str] = None,
2349 ):
2350 """
2351 Disconnect all free/available connections.
2352 This is used when a cluster node is migrated to a different address.
2354 :param moving_address_src: The address of the node that is being moved.
2355 """
2356 with self._get_pool_lock():
2357 for conn in self._get_free_connections():
2358 if self._should_update_connection(
2359 conn, "connected_address", moving_address_src
2360 ):
2361 conn.disconnect()
class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicitly supplied connection_class takes precedence over the
        # class implied by the URL scheme.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: Class used to create new connections.
        :param max_connections: Upper bound on pool size (defaults to 2**31,
            i.e. effectively unbounded).
        :param cache_factory: Optional factory used to build the client-side
            cache when ``cache_config`` is supplied without a ``cache``.
        :param maint_notifications_config: Optional maintenance notifications
            config (RESP3 only; see class docstring).
        :param connection_kwargs: Forwarded to ``connection_class``.
        :raises ValueError: On invalid ``max_connections`` or ``cache``.
        :raises RedisError: If client caching is requested without RESP3.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            # Client-side caching relies on RESP3 push invalidation messages.
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

        # The cache is owned by the pool; do not forward these kwargs to
        # individual connections.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Kwargs passed to ``connection_class`` for each new connection."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Discard all pool state and start fresh (used on init and fork)."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # NOTE: pending data is expected (not an error) when client-side
            # caching or maintenance notifications are active, since the
            # server may push messages at any time.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        kwargs = dict(self.connection_kwargs)

        if self.cache is not None:
            # Wrap the raw connection so reads/writes go through the cache.
            return CacheProxyConnection(
                self.connection_class(**kwargs), self.cache, self._lock
            )
        return self.connection_class(**kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> int:
        """Return whether *connection* was created in this process by this pool."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Set the retry policy for future and existing connections."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly acquired token.

        Free connections are re-authenticated immediately; in-use connections
        merely store the token and re-authenticate when convenient.
        """
        with self._lock:
            for conn in self._available_connections:
                # The lambdas bind ``conn`` from the current iteration and are
                # invoked synchronously by call_with_retry, so late binding is
                # not an issue here.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        """Return the lock guarding the pool's connection collections."""
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        with self._lock:
            return self._in_use_connections

    def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.

        NOTE: this must be a plain (non-async) function. The synchronous
        retry machinery calls the failure callback directly, so an
        ``async def`` here would produce a coroutine that is never awaited
        (emitting a RuntimeWarning) instead of a no-op.

        :param error:
        :return:
        """
        pass
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        """
        :param max_connections: Maximum number of connections (queue size).
        :param timeout: Seconds to wait for a connection, or None to block
            forever.
        :param connection_class: Class used to create new connections.
        :param queue_class: Queue implementation holding pooled connections.
        :param connection_kwargs: Forwarded to ``ConnectionPool.__init__``.
        """
        self.queue_class = queue_class
        self.timeout = timeout
        self._in_maintenance = False
        # Retained for backward compatibility only. Lock bookkeeping now uses
        # a per-call local flag (see _maintenance_lock_acquire/_release):
        # a shared instance attribute is racy — one thread's cleanup could
        # clear another thread's "I hold the lock" state, leaving the lock
        # permanently held.
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def _maintenance_lock_acquire(self) -> bool:
        """Acquire ``self._lock`` when the pool is in maintenance mode.

        Returns True iff the lock was acquired by this call, so the caller
        releases exactly what it acquired (a per-call flag, unlike the old
        shared ``self._locked`` attribute, cannot be clobbered by another
        thread).
        """
        if self._in_maintenance:
            self._lock.acquire()
            return True
        return False

    def _maintenance_lock_release(self, locked: bool) -> None:
        """Release ``self._lock`` iff this call-site acquired it."""
        if locked:
            self._lock.release()

    def reset(self):
        """Recreate the wait queue and forget all tracked connections."""
        # Create and fill up a thread safe queue with ``None`` values.
        locked = self._maintenance_lock_acquire()
        try:
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            self._maintenance_lock_release(locked)

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        locked = self._maintenance_lock_acquire()
        try:
            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            self._connections.append(connection)
            return connection
        finally:
            self._maintenance_lock_release(locked)

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        locked = self._maintenance_lock_acquire()
        try:
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. If you want never
                # to block, construct the pool with timeout=0.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                connection = self.make_connection()
        finally:
            self._maintenance_lock_release(locked)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()

        locked = self._maintenance_lock_acquire()
        try:
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            self._maintenance_lock_release(locked)

    def disconnect(self, inuse_connections: bool = True):
        "Disconnects either all connections in the pool or just the free connections."
        self._checkpid()
        locked = self._maintenance_lock_acquire()
        try:
            if inuse_connections:
                connections = self._connections
            else:
                connections = self._get_free_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            self._maintenance_lock_release(locked)

    def _get_free_connections(self):
        """Return the set of connections currently sitting in the queue."""
        with self._lock:
            return {conn for conn in self.pool.queue if conn}

    def _get_in_use_connections(self):
        """Return the set of created connections not currently in the queue."""
        with self._lock:
            # free connections
            connections_in_queue = {conn for conn in self.pool.queue if conn}
            # in self._connections we keep all created connections
            # so the ones that are not in the queue are the in use ones
            return {
                conn for conn in self._connections if conn not in connections_in_queue
            }

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this Blocking ConnectionPool is in maintenance mode.

        This is used to prevent new connections from being created while we are in maintenance mode.
        The pool will be in maintenance mode only when we are processing a MOVING notification.
        """
        self._in_maintenance = in_maintenance