Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .event import AfterConnectionReleasedEvent, EventDispatcher
39from .exceptions import (
40 AuthenticationError,
41 AuthenticationWrongNumberOfArgsError,
42 ChildDeadlockedError,
43 ConnectionError,
44 DataError,
45 MaxConnectionsError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from .maint_notifications import (
51 MaintenanceState,
52 MaintNotificationsConfig,
53 MaintNotificationsConnectionHandler,
54 MaintNotificationsPoolHandler,
55)
56from .retry import Retry
57from .utils import (
58 CRYPTOGRAPHY_AVAILABLE,
59 HIREDIS_AVAILABLE,
60 SSL_AVAILABLE,
61 compare_versions,
62 deprecated_args,
63 ensure_string,
64 format_error_message,
65 get_lib_version,
66 str_if_bytes,
67)
69if SSL_AVAILABLE:
70 import ssl
71 from ssl import VerifyFlags
72else:
73 ssl = None
74 VerifyFlags = None
76if HIREDIS_AVAILABLE:
77 import hiredis
# RESP wire-protocol framing tokens used by the serializers below.
SYM_STAR = b"*"  # array header prefix
SYM_DOLLAR = b"$"  # bulk-string length prefix
SYM_CRLF = b"\r\n"  # element terminator
SYM_EMPTY = b""  # joiner for concatenating byte fragments

# Protocol version used when ``int(protocol)`` raises TypeError (e.g. None).
DEFAULT_RESP_VERSION = 2

# Unique "argument not supplied" marker for parameters where None is meaningful.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser whenever it is importable.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
class HiredisRespSerializer:
    """Serialize commands to RESP using the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A multi-word command name such as ``"CONFIG GET"`` is split so every
        word travels as its own argument. Returns a one-element list holding
        the packed bytes. A ``TypeError`` raised by hiredis is re-raised as
        ``DataError`` carrying the original traceback.
        """
        head, rest = args[0], args[1:]
        if isinstance(head, str):
            args = (*head.encode().split(), *rest)
        elif b" " in head:
            args = (*head.split(), *rest)

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Pure-Python RESP serializer, used when hiredis is not available."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Anything longer than this many bytes is emitted as its own output
        # chunk instead of being copied into the running buffer.
        self._buffer_cutoff = buffer_cutoff
        # Callable that turns each command argument into bytes/memoryview.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks whose concatenation is one RESP array
        of bulk strings. To avoid large string mallocs, large values and
        memoryviews are appended as separate chunks rather than joined into
        the surrounding buffer.
        """
        # A command name like 'CONFIG GET' embeds literal sub-arguments that
        # the server expects as separate elements, so split it up front. The
        # pieces are bytestrings, so the encoder passes them through unchanged.
        first, remaining = args[0], args[1:]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + remaining
        elif b" " in first:
            args = tuple(first.split()) + remaining

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            must_flush = (
                isinstance(encoded, memoryview)
                or size > cutoff
                or len(pending) > cutoff
            )
            if must_flush:
                # Emit the header, then hand over the payload without copying it.
                chunks.append(pending + header)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))
        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """
    Declares the operations a Redis connection exposes.

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the contract but are not enforced
    when a subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs used to build ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register *callback* to run whenever the connection is (re)established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Stop notifying a callback given to ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Instantiate *parser_class* and make it this connection's parser."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version configured for this connection."""

    @abstractmethod
    def connect(self):
        """Connect to the server unless already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake (auth, protocol, database selection)."""

    @abstractmethod
    def disconnect(self, *args):
        """Drop the connection to the server."""

    @abstractmethod
    def check_health(self):
        """Verify the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send a command that has already been serialized."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Serialize and send a command (presumably via ``pack_command``)."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Presumably report whether a reply is readable within *timeout*."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the next reply from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize a single command's arguments into wire format."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands into wire format."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """The server's reply to the HELLO handshake, if one was sent."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store *token* for a later ``re_auth`` call (implementation-defined)."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection (implementation-defined)."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.

    Note: the class does not derive from ``abc.ABC``, so the ``@abstractmethod``
    hooks below document what a concrete connection must provide but are not
    enforced when an instance is created.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.

        Raises:
            RedisError: if notifications are enabled but ``parser`` is missing
                or is neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        When notifications are disabled, both handler attributes are set to
        None and the ``orig_*`` attributes are NOT assigned.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing
                or is neither a ``_HiredisParser`` nor a ``_RESP3Parser``.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters; falsy overrides fall back to
        # the connection's current values.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Attach a per-connection copy of *maint_notifications_pool_handler* to
        this connection's parser, creating (or re-configuring) the connection
        handler from the pool handler's config.
        """
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """
        Send the maintenance-notifications handshake when the protocol is not
        RESP2, notifications are enabled, a connection handler exists and the
        connection has a ``host`` attribute. Otherwise this is a no-op.
        """
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Send ``CLIENT MAINT_NOTIFICATIONS ON moving-endpoint-type <type>``.

        Raises:
            ValueError: if the connection has no host.
            ResponseError: if the server does not reply "OK", unless the
                config's ``enabled`` is "auto", in which case a warning is
                logged instead.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.warning("Failed to enable maintenance notifications: %s", e)
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address of this connection's peer.

        First tries to get the actual IP from the connected socket (most
        accurate), then falls back to resolving ``host``/``port`` with
        ``socket.getaddrinfo`` if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """
        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # AF_INET/AF_INET6 peers are (host, port[, ...]) tuples. Other
                # families (e.g. AF_UNIX, which reports a path string) carry no
                # IP; indexing such a string would return its first character,
                # so fall through to DNS resolution instead.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.

        For tuple addresses (AF_INET/AF_INET6) the host part is returned; any
        other address form (e.g. an AF_UNIX path string) is returned unchanged
        rather than indexed. Returns None when there is no socket.
        """
        conn_socket = self._get_socket()
        if not conn_socket:
            return None
        peer_addr = conn_socket.getpeername()
        if isinstance(peer_addr, tuple):
            return peer_addr[0]
        return peer_addr

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Apply *relaxed_timeout* to the live socket and the parser; -1 means
        "restore the connection's configured socket_timeout".
        No-op when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate *timeout* to the parser's read buffer / hiredis reader."""
        parser = self._get_parser()
        if parser and parser._buffer:
            # NOTE(review): a falsy timeout (None/0) is not pushed into the
            # RESP3 buffer, so it keeps its previous value - confirm intended.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        A falsy ``tmp_host_address`` also leaves the host untouched. A
        ``tmp_relaxed_timeout`` of -1 leaves both timeouts untouched; any other
        value, including the default ``None``, is written to both
        ``socket_timeout`` and ``socket_connect_timeout``.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Restore the host and/or timeouts saved by
        ``_configure_maintenance_notifications`` (the ``orig_*`` attributes).
        """
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
656 "Manages communication to and from a Redis server"
658 def __init__(
659 self,
660 db: int = 0,
661 password: Optional[str] = None,
662 socket_timeout: Optional[float] = None,
663 socket_connect_timeout: Optional[float] = None,
664 retry_on_timeout: bool = False,
665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
666 encoding: str = "utf-8",
667 encoding_errors: str = "strict",
668 decode_responses: bool = False,
669 parser_class=DefaultParser,
670 socket_read_size: int = 65536,
671 health_check_interval: int = 0,
672 client_name: Optional[str] = None,
673 lib_name: Optional[str] = "redis-py",
674 lib_version: Optional[str] = get_lib_version(),
675 username: Optional[str] = None,
676 retry: Union[Any, None] = None,
677 redis_connect_func: Optional[Callable[[], None]] = None,
678 credential_provider: Optional[CredentialProvider] = None,
679 protocol: Optional[int] = 2,
680 command_packer: Optional[Callable[[], None]] = None,
681 event_dispatcher: Optional[EventDispatcher] = None,
682 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
683 maint_notifications_pool_handler: Optional[
684 MaintNotificationsPoolHandler
685 ] = None,
686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
687 maintenance_notification_hash: Optional[int] = None,
688 orig_host_address: Optional[str] = None,
689 orig_socket_timeout: Optional[float] = None,
690 orig_socket_connect_timeout: Optional[float] = None,
691 ):
692 """
693 Initialize a new Connection.
694 To specify a retry policy for specific errors, first set
695 `retry_on_error` to a list of the error/s to retry on, then set
696 `retry` to a valid `Retry` object.
697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
698 """
699 if (username or password) and credential_provider is not None:
700 raise DataError(
701 "'username' and 'password' cannot be passed along with 'credential_"
702 "provider'. Please provide only one of the following arguments: \n"
703 "1. 'password' and (optional) 'username'\n"
704 "2. 'credential_provider'"
705 )
706 if event_dispatcher is None:
707 self._event_dispatcher = EventDispatcher()
708 else:
709 self._event_dispatcher = event_dispatcher
710 self.pid = os.getpid()
711 self.db = db
712 self.client_name = client_name
713 self.lib_name = lib_name
714 self.lib_version = lib_version
715 self.credential_provider = credential_provider
716 self.password = password
717 self.username = username
718 self._socket_timeout = socket_timeout
719 if socket_connect_timeout is None:
720 socket_connect_timeout = socket_timeout
721 self._socket_connect_timeout = socket_connect_timeout
722 self.retry_on_timeout = retry_on_timeout
723 if retry_on_error is SENTINEL:
724 retry_on_errors_list = []
725 else:
726 retry_on_errors_list = list(retry_on_error)
727 if retry_on_timeout:
728 # Add TimeoutError to the errors list to retry on
729 retry_on_errors_list.append(TimeoutError)
730 self.retry_on_error = retry_on_errors_list
731 if retry or self.retry_on_error:
732 if retry is None:
733 self.retry = Retry(NoBackoff(), 1)
734 else:
735 # deep-copy the Retry object as it is mutable
736 self.retry = copy.deepcopy(retry)
737 if self.retry_on_error:
738 # Update the retry's supported errors with the specified errors
739 self.retry.update_supported_errors(self.retry_on_error)
740 else:
741 self.retry = Retry(NoBackoff(), 0)
742 self.health_check_interval = health_check_interval
743 self.next_health_check = 0
744 self.redis_connect_func = redis_connect_func
745 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
746 self.handshake_metadata = None
747 self._sock = None
748 self._socket_read_size = socket_read_size
749 self._connect_callbacks = []
750 self._buffer_cutoff = 6000
751 self._re_auth_token: Optional[TokenInterface] = None
752 try:
753 p = int(protocol)
754 except TypeError:
755 p = DEFAULT_RESP_VERSION
756 except ValueError:
757 raise ConnectionError("protocol must be an integer")
758 finally:
759 if p < 2 or p > 3:
760 raise ConnectionError("protocol must be either 2 or 3")
761 # p = DEFAULT_RESP_VERSION
762 self.protocol = p
763 if self.protocol == 3 and parser_class == _RESP2Parser:
764 # If the protocol is 3 but the parser is RESP2, change it to RESP3
765 # This is needed because the parser might be set before the protocol
766 # or might be provided as a kwarg to the constructor
767 # We need to react on discrepancy only for RESP2 and RESP3
768 # as hiredis supports both
769 parser_class = _RESP3Parser
770 self.set_parser(parser_class)
772 self._command_packer = self._construct_command_packer(command_packer)
773 self._should_reconnect = False
775 # Set up maintenance notifications
776 MaintNotificationsAbstractConnection.__init__(
777 self,
778 maint_notifications_config,
779 maint_notifications_pool_handler,
780 maintenance_state,
781 maintenance_notification_hash,
782 orig_host_address,
783 orig_socket_timeout,
784 orig_socket_connect_timeout,
785 self._parser,
786 )
788 def __repr__(self):
789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
792 @abstractmethod
793 def repr_pieces(self):
794 pass
796 def __del__(self):
797 try:
798 self.disconnect()
799 except Exception:
800 pass
802 def _construct_command_packer(self, packer):
803 if packer is not None:
804 return packer
805 elif HIREDIS_AVAILABLE:
806 return HiredisRespSerializer()
807 else:
808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
810 def register_connect_callback(self, callback):
811 """
812 Register a callback to be called when the connection is established either
813 initially or reconnected. This allows listeners to issue commands that
814 are ephemeral to the connection, for example pub/sub subscription or
815 key tracking. The callback must be a _method_ and will be kept as
816 a weak reference.
817 """
818 wm = weakref.WeakMethod(callback)
819 if wm not in self._connect_callbacks:
820 self._connect_callbacks.append(wm)
822 def deregister_connect_callback(self, callback):
823 """
824 De-register a previously registered callback. It will no-longer receive
825 notifications on connection events. Calling this is not required when the
826 listener goes away, since the callbacks are kept as weak methods.
827 """
828 try:
829 self._connect_callbacks.remove(weakref.WeakMethod(callback))
830 except ValueError:
831 pass
833 def set_parser(self, parser_class):
834 """
835 Creates a new instance of parser_class with socket size:
836 _socket_read_size and assigns it to the parser for the connection
837 :param parser_class: The required parser class
838 """
839 self._parser = parser_class(socket_read_size=self._socket_read_size)
841 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
842 return self._parser
844 def connect(self):
845 "Connects to the Redis server if not already connected"
846 self.connect_check_health(check_health=True)
848 def connect_check_health(
849 self, check_health: bool = True, retry_socket_connect: bool = True
850 ):
851 if self._sock:
852 return
853 try:
854 if retry_socket_connect:
855 sock = self.retry.call_with_retry(
856 lambda: self._connect(), lambda error: self.disconnect(error)
857 )
858 else:
859 sock = self._connect()
860 except socket.timeout:
861 raise TimeoutError("Timeout connecting to server")
862 except OSError as e:
863 raise ConnectionError(self._error_message(e))
865 self._sock = sock
866 try:
867 if self.redis_connect_func is None:
868 # Use the default on_connect function
869 self.on_connect_check_health(check_health=check_health)
870 else:
871 # Use the passed function redis_connect_func
872 self.redis_connect_func(self)
873 except RedisError:
874 # clean up after any error in on_connect
875 self.disconnect()
876 raise
878 # run any user callbacks. right now the only internal callback
879 # is for pubsub channel/pattern resubscription
880 # first, remove any dead weakrefs
881 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
882 for ref in self._connect_callbacks:
883 callback = ref()
884 if callback:
885 callback(self)
887 @abstractmethod
888 def _connect(self):
889 pass
891 @abstractmethod
892 def _host_error(self):
893 pass
895 def _error_message(self, exception):
896 return format_error_message(self._host_error(), exception)
898 def on_connect(self):
899 self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        """
        Initialize the connection, authenticate and select a database.

        Steps, in order: attach the parser; authenticate (HELLO ... AUTH when
        the protocol is not 2, plain AUTH otherwise); if there are no
        credentials but the protocol is not 2, send a bare HELLO; activate
        maintenance notifications if configured; then CLIENT SETNAME,
        CLIENT SETINFO LIB-NAME / LIB-VER (ResponseErrors ignored) and SELECT.

        Args:
            check_health: forwarded to send_command for the bare HELLO and all
                later commands; the AUTH and HELLO+AUTH commands always use
                check_health=False.

        Raises:
            AuthenticationError: the AUTH reply was not "OK".
            ConnectionError: the bare HELLO reply's "proto" differs from
                self.protocol, or CLIENT SETNAME / SELECT did not reply "OK".
        """
        self._parser.on_connect(self)
        # Keep a handle on the original parser: if it is replaced by a RESP3
        # parser below, its EXCEPTION_CLASSES are carried over to the new one.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # A lone password is sent as the "default" user's password.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # NOTE(review): unlike the bare-HELLO branch below, the negotiated
            # protocol is not validated here; the check remains disabled:
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The "proto" key may arrive as bytes or str depending on decoding.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # CLIENT SETINFO is best effort: a ResponseError (for example from a
        # server that rejects the command) is ignored.
        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
1017 def disconnect(self, *args):
1018 "Disconnects from the Redis server"
1019 self._parser.on_disconnect()
1021 conn_sock = self._sock
1022 self._sock = None
1023 # reset the reconnect flag
1024 self.reset_should_reconnect()
1025 if conn_sock is None:
1026 return
1028 if os.getpid() == self.pid:
1029 try:
1030 conn_sock.shutdown(socket.SHUT_RDWR)
1031 except (OSError, TypeError):
1032 pass
1034 try:
1035 conn_sock.close()
1036 except OSError:
1037 pass
1039 def mark_for_reconnect(self):
1040 self._should_reconnect = True
1042 def should_reconnect(self):
1043 return self._should_reconnect
1045 def reset_should_reconnect(self):
1046 self._should_reconnect = False
1048 def _send_ping(self):
1049 """Send PING, expect PONG in return"""
1050 self.send_command("PING", check_health=False)
1051 if str_if_bytes(self.read_response()) != "PONG":
1052 raise ConnectionError("Bad response from PING health check")
1054 def _ping_failed(self, error):
1055 """Function to call when PING fails"""
1056 self.disconnect()
1058 def check_health(self):
1059 """Check the health of the connection with a PING/PONG"""
1060 if self.health_check_interval and time.monotonic() > self.next_health_check:
1061 self.retry.call_with_retry(self._send_ping, self._ping_failed)
1063 def send_packed_command(self, command, check_health=True):
1064 """Send an already packed command to the Redis server"""
1065 if not self._sock:
1066 self.connect_check_health(check_health=False)
1067 # guard against health check recursion
1068 if check_health:
1069 self.check_health()
1070 try:
1071 if isinstance(command, str):
1072 command = [command]
1073 for item in command:
1074 self._sock.sendall(item)
1075 except socket.timeout:
1076 self.disconnect()
1077 raise TimeoutError("Timeout writing to socket")
1078 except OSError as e:
1079 self.disconnect()
1080 if len(e.args) == 1:
1081 errno, errmsg = "UNKNOWN", e.args[0]
1082 else:
1083 errno = e.args[0]
1084 errmsg = e.args[1]
1085 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1086 except BaseException:
1087 # BaseExceptions can be raised when a socket send operation is not
1088 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1089 # to send un-sent data. However, the send_packed_command() API
1090 # does not support it so there is no point in keeping the connection open.
1091 self.disconnect()
1092 raise
1094 def send_command(self, *args, **kwargs):
1095 """Pack and send a command to the Redis server"""
1096 self.send_packed_command(
1097 self._command_packer.pack(*args),
1098 check_health=kwargs.get("check_health", True),
1099 )
1101 def can_read(self, timeout=0):
1102 """Poll the socket to see if there's data that can be read."""
1103 sock = self._sock
1104 if not sock:
1105 self.connect()
1107 host_error = self._host_error()
1109 try:
1110 return self._parser.can_read(timeout)
1112 except OSError as e:
1113 self.disconnect()
1114 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
1116 def read_response(
1117 self,
1118 disable_decoding=False,
1119 *,
1120 disconnect_on_error=True,
1121 push_request=False,
1122 ):
1123 """Read the response from a previously sent command"""
1125 host_error = self._host_error()
1127 try:
1128 if self.protocol in ["3", 3]:
1129 response = self._parser.read_response(
1130 disable_decoding=disable_decoding, push_request=push_request
1131 )
1132 else:
1133 response = self._parser.read_response(disable_decoding=disable_decoding)
1134 except socket.timeout:
1135 if disconnect_on_error:
1136 self.disconnect()
1137 raise TimeoutError(f"Timeout reading from {host_error}")
1138 except OSError as e:
1139 if disconnect_on_error:
1140 self.disconnect()
1141 raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
1142 except BaseException:
1143 # Also by default close in case of BaseException. A lot of code
1144 # relies on this behaviour when doing Command/Response pairs.
1145 # See #1128.
1146 if disconnect_on_error:
1147 self.disconnect()
1148 raise
1150 if self.health_check_interval:
1151 self.next_health_check = time.monotonic() + self.health_check_interval
1153 if isinstance(response, ResponseError):
1154 try:
1155 raise response
1156 finally:
1157 del response # avoid creating ref cycles
1158 return response
1160 def pack_command(self, *args):
1161 """Pack a series of arguments into the Redis protocol"""
1162 return self._command_packer.pack(*args)
1164 def pack_commands(self, commands):
1165 """Pack multiple commands into the Redis protocol"""
1166 output = []
1167 pieces = []
1168 buffer_length = 0
1169 buffer_cutoff = self._buffer_cutoff
1171 for cmd in commands:
1172 for chunk in self._command_packer.pack(*cmd):
1173 chunklen = len(chunk)
1174 if (
1175 buffer_length > buffer_cutoff
1176 or chunklen > buffer_cutoff
1177 or isinstance(chunk, memoryview)
1178 ):
1179 if pieces:
1180 output.append(SYM_EMPTY.join(pieces))
1181 buffer_length = 0
1182 pieces = []
1184 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1185 output.append(chunk)
1186 else:
1187 pieces.append(chunk)
1188 buffer_length += chunklen
1190 if pieces:
1191 output.append(SYM_EMPTY.join(pieces))
1192 return output
1194 def get_protocol(self) -> Union[int, str]:
1195 return self.protocol
1197 @property
1198 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1199 return self._handshake_metadata
1201 @handshake_metadata.setter
1202 def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
1203 self._handshake_metadata = value
1205 def set_re_auth_token(self, token: TokenInterface):
1206 self._re_auth_token = token
1208 def re_auth(self):
1209 if self._re_auth_token is not None:
1210 self.send_command(
1211 "AUTH",
1212 self._re_auth_token.try_get("oid"),
1213 self._re_auth_token.get_value(),
1214 )
1215 self.read_response()
1216 self._re_auth_token = None
1218 def _get_socket(self) -> Optional[socket.socket]:
1219 return self._sock
1221 @property
1222 def socket_timeout(self) -> Optional[Union[float, int]]:
1223 return self._socket_timeout
1225 @socket_timeout.setter
1226 def socket_timeout(self, value: Optional[Union[float, int]]):
1227 self._socket_timeout = value
1229 @property
1230 def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1231 return self._socket_connect_timeout
1233 @socket_connect_timeout.setter
1234 def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1235 self._socket_connect_timeout = value
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        # TCP-specific settings are assigned before the base initializer runs
        # (same order as before; the base class may read them).
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """(name, value) pairs for repr(): host, port, db and, if set, client_name."""
        extra = [("client_name", self.client_name)] if self.client_name else []
        return [("host", self.host), ("port", self.port), ("db", self.db), *extra]

    def _connect(self):
        """Create a TCP socket connection, trying each resolved address in turn.

        Mimics socket.create_connection() for IPv4/IPv6 support, but sets
        socket options before connect(). Re-raises the last OSError if every
        address fails.
        """
        last_error = None
        candidates = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )
        for family, socktype, proto, _canonname, address in candidates:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, value in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, option, value)

                # Connect under the connect timeout, then switch to the
                # regular I/O timeout for the established connection.
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)
                sock.settimeout(self.socket_timeout)
                return sock
            except OSError as exc:
                # Keep a reference: ``exc`` itself is unbound after this block.
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Describe the endpoint as ``host:port`` for error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Wrap a connection and serve cachable commands from a client-side cache.

    Commands the cache reports as cachable are answered from ``cache`` once a
    valid entry exists; everything else is forwarded to the wrapped
    connection. CLIENT TRACKING is switched on by a connect callback, and
    server invalidation pushes evict matching cache entries.
    """

    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and check the server supports caching.

        Raises:
            ConnectionError: the handshake metadata lacks the server name or
                version, or the server is not "redis" at a version accepted by
                the MIN_ALLOWED_VERSION check (Redis 7.4+, per the message).
        """
        self._conn.connect()

        # Handshake metadata keys are bytes or str depending on decoding.
        metadata = self._conn.handshake_metadata
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        # Both values are required by the compatibility check below.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the whole cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Forward a packed command to the wrapped connection (never cached)."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, unless its reply can be served from the cache.

        Non-cachable commands are forwarded unchanged. Cachable commands need
        a ``keys`` keyword to build the cache key. If an entry already exists,
        nothing is sent and read_response() serves it. Otherwise an
        IN_PROGRESS placeholder entry is stored and the command is sent.

        Raises:
            ValueError: a cachable command was sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys"))
        )

        with self._cache_lock:
            entry = self._cache.get(self._current_command_cache_key)
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            # NOTE(review): an entry still IN_PROGRESS on another connection
            # also takes this path, so nothing is sent here while
            # read_response() will then read from the socket; confirm
            # whether that race needs handling.
            if entry:
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply for the last sent command, from cache when possible.

        A cache entry that is no longer IN_PROGRESS is returned as a deep copy
        without touching the socket. Otherwise the reply is read from the
        wrapped connection. If the placeholder entry still exists, the reply
        is stored there as VALID. None replies are never cached.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            cache_key = self._current_command_cache_key
            if cache_key is not None:
                entry = self._cache.get(cache_key)
                if entry is not None and entry.status != CacheEntryStatus.IN_PROGRESS:
                    res = copy.deepcopy(entry.cache_value)
                    self._current_command_cache_key = None
                    return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # Return the wrapped connection's description (previously discarded).
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable CLIENT TRACKING and hook invalidation pushes."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain pending push messages (e.g. invalidations) from the socket."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys; ``data[1] is None`` means flush everything."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.

        The plain TCP socket is closed if wrapping fails for any reason, so
        configuration or handshake errors never leak a file descriptor.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: pure OCSP was requested without the cryptography
                package, or both OCSP modes were requested.
            ConnectionError: pure OCSP validation failed.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # wrap_socket() detaches ``sock``, so closing it no longer releases the
        # descriptor; close ``sslsock`` ourselves if validation fails.
        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Check the server's stapled OCSP response over a separate connection."""
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; always release it once the check is done
        raw_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, raw_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            raw_sock.close()
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer, so these values take precedence
        # over anything it may have set for the same attributes.
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """(name, value) pairs for repr(): path, db and, if set, client_name."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces += [("client_name", self.client_name)]
        return pieces

    def _connect(self):
        """Open a Unix domain stream socket to ``self.path``."""
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Close explicitly so a failed attempt does not leave an unclosed
            # socket behind (ResourceWarning).
            try:
                uds.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            uds.close()
            raise
        # Switch from the connect timeout to the regular I/O timeout.
        uds.settimeout(self.socket_timeout)
        return uds

    def _host_error(self):
        """Describe the endpoint (the socket path) for error messages."""
        return self.path
# Upper-cased spellings that to_bool() maps to False (compared case-insensitively).
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Convert a URL query value to a boolean.

    Returns None for None or "", False for a string whose upper-cased form is
    in FALSE_STRINGS, and ``bool(value)`` for anything else.
    """
    if value is None or value == "":
        return None
    is_false_word = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_word else bool(value)
def parse_ssl_verify_flags(value):
    """Parse a comma-separated list of ssl.VerifyFlags member names.

    ``value`` looks like "VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN",
    optionally wrapped in square brackets. Surrounding whitespace is ignored
    and empty items (e.g. from a trailing comma or "[]") are skipped.

    Returns:
        A list of ssl.VerifyFlags members, in input order.

    Raises:
        ValueError: an item is not a VerifyFlags member name. When the ssl
            module is unavailable (VerifyFlags is None), every item is invalid.
    """
    # Accept only real enum member names: hasattr() would also accept
    # arbitrary attributes of the class such as "mro" or "__class__".
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for flag in value.replace("[", "").replace("]", "").split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
# Converters applied by parse_url() to known query-string arguments; any
# argument not listed here is passed through as an (unquoted) string.
URL_QUERY_ARGUMENT_PARSERS = dict(
    db=int,
    socket_timeout=float,
    socket_connect_timeout=float,
    socket_keepalive=to_bool,
    retry_on_timeout=to_bool,
    # NOTE(review): list() on a query string yields its individual
    # characters, not a list of error names; confirm the intended format.
    retry_on_error=list,
    max_connections=int,
    health_check_interval=int,
    ssl_check_hostname=to_bool,
    ssl_include_verify_flags=parse_ssl_verify_flags,
    ssl_exclude_verify_flags=parse_ssl_verify_flags,
    timeout=float,
)
def parse_url(url):
    """Parse a redis://, rediss:// or unix:// URL into connection kwargs.

    Query arguments listed in URL_QUERY_ARGUMENT_PARSERS are converted; others
    are kept as strings. Username/password come from the URL's userinfo. For
    unix:// the path becomes ``path``. For redis(s):// it yields ``host`` and
    ``port``, and the path is used as ``db`` when no ``db`` query argument was
    given.

    Raises:
        ValueError: unsupported scheme, or a query value its converter rejects.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        if not value:
            continue
        value = unquote(value[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError) as e:
                # Keep the converter's error as the cause for debugging.
                raise ValueError(
                    f"Invalid value for '{name}' in connection URL."
                ) from e
        else:
            kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
# TypeVar bound to ConnectionPool through a string forward reference (the
# class is presumably defined later in this module), so annotations can refer
# to "a ConnectionPool or subclass" generically.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface that connection pool implementations provide."""

    @abstractmethod
    def get_protocol(self):
        """Return the protocol setting used by the pool's connections."""

    @abstractmethod
    def reset(self):
        """Reset the pool's state."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Return a connection from the pool.

        Passing any arguments is deprecated (since 5.3.0).
        """

    @abstractmethod
    def get_encoder(self):
        """Return the encoder used by the pool."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Give ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect the pool's connections.

        ``inuse_connections`` presumably selects whether connections that are
        currently checked out are included.
        """

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Apply ``retry`` as the retry policy."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Handle a new authentication ``token``."""
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.

    Concrete pools must provide ``connection_kwargs`` (getter and setter),
    ``_get_pool_lock()``, ``_get_free_connections()`` and
    ``_get_in_use_connections()``.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        Set up the maintenance notifications pool handler.

        :param maint_notifications_config: explicit configuration. When it is
            omitted and ``kwargs["protocol"]`` is 3, a default
            ``MaintNotificationsConfig()`` is used.
        :param kwargs: the pool's connection kwargs; only ``protocol`` is read.
        :raises RedisError: if an enabled configuration is used with a
            protocol other than RESP3.
        """
        # Initialize maintenance notifications
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct new connections."""

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        """Replace the keyword arguments used to construct new connections."""

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock guarding the pool's connection containers."""

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the idle connections owned by the pool."""

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently handed out by the pool."""

    def maint_notifications_enabled(self) -> bool:
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        handler = self._maint_notifications_pool_handler
        if not handler:
            return False
        config = handler.config
        # bool() so callers always get True/False, never None or the raw
        # ``enabled`` value, as promised above.
        return bool(config and config.enabled)

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        :raises ValueError: if notifications are currently enabled and the new
            configuration would disable them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.

        Does nothing unless notifications are enabled. The original host and
        socket timeouts are recorded only once, the first time this runs, so
        later calls do not overwrite them.
        """
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the maintenance notifications config for all connections in the pool.

        Idle connections are disconnected right away; in-use connections are
        only marked for reconnect, since another caller is using them.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        ``matching_pattern`` selects what is compared:

        * ``"connected_address"``: ``conn.getpeername()`` vs ``matching_address``
        * ``"configured_address"``: ``conn.host`` vs ``matching_address``
        * ``"notification_hash"``: ``conn.maintenance_notification_hash`` vs
          ``matching_notification_hash``

        A falsy ``matching_address`` or a ``None`` ``matching_notification_hash``
        disables the filter, so every connection matches. An unknown pattern
        also matches every connection.
        """
        if matching_pattern == "connected_address":
            return not matching_address or conn.getpeername() == matching_address
        if matching_pattern == "configured_address":
            return not matching_address or conn.host == matching_address
        if matching_pattern == "notification_hash":
            # Compare against None, not truthiness: 0 is a legitimate hash value.
            return (
                matching_notification_hash is None
                or conn.maintenance_notification_hash == matching_notification_hash
            )
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Temporary settings are applied first and any requested resets after
        them. ``conn.update_current_socket_timeout(relaxed_timeout)`` is
        always called, including when ``relaxed_timeout`` is ``None``.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            # In-use connections first, then (optionally) the idle ones.
            candidates = chain(
                self._get_in_use_connections(),
                self._get_free_connections() if include_free_connections else (),
            )
            for conn in candidates:
                if not self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    continue
                self.update_connection_settings(
                    conn,
                    state=state,
                    maintenance_notification_hash=maintenance_notification_hash,
                    host_address=host_address,
                    relaxed_timeout=relaxed_timeout,
                    update_notification_hash=update_notification_hash,
                    reset_host_address=reset_host_address,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every in-use connection is marked.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every idle connection is disconnected.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win. The one exception is ``connection_class``:
        an explicit keyword argument overrides the URL-derived one.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        # A falsy max_connections (None or 0) means "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

            # The cache lives on the pool; don't forward these options to
            # the connection constructor.
            connection_kwargs.pop("cache", None)
            connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct new connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Forget all connections and record the current process id."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # With a client-side cache or maintenance notifications enabled,
            # pending data on an idle connection is not treated as an error.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new connection.

        :raises MaxConnectionsError: if ``max_connections`` connections have
            already been created.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")

        kwargs = dict(self.connection_kwargs)
        connection = self.connection_class(**kwargs)
        if self.cache is not None:
            connection = CacheProxyConnection(connection, self.cache, self._lock)
        # Count the connection only once construction succeeded, so a failing
        # constructor does not permanently consume pool capacity.
        self._created_connections += 1
        return connection

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created in this pool's current process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """
        Use ``retry`` for future connections and for every existing one.

        Existing connections are reached through ``_get_free_connections()``
        and ``_get_in_use_connections()`` so that subclasses which track
        connections differently (``BlockingConnectionPool`` has no
        ``_available_connections``) are handled too. The walk happens under
        the pool lock, the same lock ``get_connection()`` and ``release()``
        hold while mutating the containers.
        """
        with self._lock:
            self.connection_kwargs.update({"retry": retry})
            for conn in chain(
                self._get_free_connections(), self._get_in_use_connections()
            ):
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate the pool's connections with ``token``.

        Idle connections are sent ``AUTH <oid> <token value>`` immediately,
        using each connection's retry policy. In-use connections are handed
        the token through ``set_re_auth_token()`` instead.
        """
        with self._lock:
            for conn in self._get_free_connections():
                # The lambdas are invoked immediately inside this iteration,
                # so closing over the loop variable is safe here.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._get_in_use_connections():
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        with self._lock:
            return self._in_use_connections

    def _mock(self, error: RedisError):
        """
        No-op failure callback passed to ``Retry.call_with_retry``.

        Deliberately synchronous: the callback is invoked directly (via
        ``lambda error: self._mock(error)``), so an ``async def`` would only
        create a coroutine that is never awaited and trigger a RuntimeWarning.

        :param error: the error reported by the retry helper (ignored).
        :return: None
        """
        pass
2743class BlockingConnectionPool(ConnectionPool):
2744 """
2745 Thread-safe blocking connection pool::
2747 >>> from redis.client import Redis
2748 >>> client = Redis(connection_pool=BlockingConnectionPool())
2750 It performs the same function as the default
2751 :py:class:`~redis.ConnectionPool` implementation, in that,
2752 it maintains a pool of reusable connections that can be shared by
2753 multiple redis clients (safely across threads if required).
2755 The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather than
2757 raising a :py:class:`~redis.ConnectionError` (as the default
2758 :py:class:`~redis.ConnectionPool` implementation does), it
2759 makes the client wait ("blocks") for a specified number of seconds until
2760 a connection becomes available.
2762 Use ``max_connections`` to increase / decrease the pool size::
2764 >>> pool = BlockingConnectionPool(max_connections=10)
2766 Use ``timeout`` to tell it either how many seconds to wait for a connection
2767 to become available, or to block forever:
2769 >>> # Block forever.
2770 >>> pool = BlockingConnectionPool(timeout=None)
2772 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
2773 >>> # not available.
2774 >>> pool = BlockingConnectionPool(timeout=5)
2775 """
2777 def __init__(
2778 self,
2779 max_connections=50,
2780 timeout=20,
2781 connection_class=Connection,
2782 queue_class=LifoQueue,
2783 **connection_kwargs,
2784 ):
2785 self.queue_class = queue_class
2786 self.timeout = timeout
2787 self._in_maintenance = False
2788 self._locked = False
2789 super().__init__(
2790 connection_class=connection_class,
2791 max_connections=max_connections,
2792 **connection_kwargs,
2793 )
2795 def reset(self):
2796 # Create and fill up a thread safe queue with ``None`` values.
2797 try:
2798 if self._in_maintenance:
2799 self._lock.acquire()
2800 self._locked = True
2801 self.pool = self.queue_class(self.max_connections)
2802 while True:
2803 try:
2804 self.pool.put_nowait(None)
2805 except Full:
2806 break
2808 # Keep a list of actual connection instances so that we can
2809 # disconnect them later.
2810 self._connections = []
2811 finally:
2812 if self._locked:
2813 try:
2814 self._lock.release()
2815 except Exception:
2816 pass
2817 self._locked = False
2819 # this must be the last operation in this method. while reset() is
2820 # called when holding _fork_lock, other threads in this process
2821 # can call _checkpid() which compares self.pid and os.getpid() without
2822 # holding any lock (for performance reasons). keeping this assignment
2823 # as the last operation ensures that those other threads will also
2824 # notice a pid difference and block waiting for the first thread to
2825 # release _fork_lock. when each of these threads eventually acquire
2826 # _fork_lock, they will notice that another thread already called
2827 # reset() and they will immediately release _fork_lock and continue on.
2828 self.pid = os.getpid()
2830 def make_connection(self):
2831 "Make a fresh connection."
2832 try:
2833 if self._in_maintenance:
2834 self._lock.acquire()
2835 self._locked = True
2837 if self.cache is not None:
2838 connection = CacheProxyConnection(
2839 self.connection_class(**self.connection_kwargs),
2840 self.cache,
2841 self._lock,
2842 )
2843 else:
2844 connection = self.connection_class(**self.connection_kwargs)
2845 self._connections.append(connection)
2846 return connection
2847 finally:
2848 if self._locked:
2849 try:
2850 self._lock.release()
2851 except Exception:
2852 pass
2853 self._locked = False
2855 @deprecated_args(
2856 args_to_warn=["*"],
2857 reason="Use get_connection() without args instead",
2858 version="5.3.0",
2859 )
2860 def get_connection(self, command_name=None, *keys, **options):
2861 """
2862 Get a connection, blocking for ``self.timeout`` until a connection
2863 is available from the pool.
2865 If the connection returned is ``None`` then creates a new connection.
2866 Because we use a last-in first-out queue, the existing connections
2867 (having been returned to the pool after the initial ``None`` values
2868 were added) will be returned before ``None`` values. This means we only
2869 create new connections when we need to, i.e.: the actual number of
2870 connections will only increase in response to demand.
2871 """
2872 # Make sure we haven't changed process.
2873 self._checkpid()
2875 # Try and get a connection from the pool. If one isn't available within
2876 # self.timeout then raise a ``ConnectionError``.
2877 connection = None
2878 try:
2879 if self._in_maintenance:
2880 self._lock.acquire()
2881 self._locked = True
2882 try:
2883 connection = self.pool.get(block=True, timeout=self.timeout)
2884 except Empty:
2885 # Note that this is not caught by the redis client and will be
2886 # raised unless handled by application code. If you want never to
2887 raise ConnectionError("No connection available.")
2889 # If the ``connection`` is actually ``None`` then that's a cue to make
2890 # a new connection to add to the pool.
2891 if connection is None:
2892 connection = self.make_connection()
2893 finally:
2894 if self._locked:
2895 try:
2896 self._lock.release()
2897 except Exception:
2898 pass
2899 self._locked = False
2901 try:
2902 # ensure this connection is connected to Redis
2903 connection.connect()
2904 # connections that the pool provides should be ready to send
2905 # a command. if not, the connection was either returned to the
2906 # pool before all data has been read or the socket has been
2907 # closed. either way, reconnect and verify everything is good.
2908 try:
2909 if connection.can_read():
2910 raise ConnectionError("Connection has data")
2911 except (ConnectionError, TimeoutError, OSError):
2912 connection.disconnect()
2913 connection.connect()
2914 if connection.can_read():
2915 raise ConnectionError("Connection not ready")
2916 except BaseException:
2917 # release the connection back to the pool so that we don't leak it
2918 self.release(connection)
2919 raise
2921 return connection
2923 def release(self, connection):
2924 "Releases the connection back to the pool."
2925 # Make sure we haven't changed process.
2926 self._checkpid()
2928 try:
2929 if self._in_maintenance:
2930 self._lock.acquire()
2931 self._locked = True
2932 if not self.owns_connection(connection):
2933 # pool doesn't own this connection. do not add it back
2934 # to the pool. instead add a None value which is a placeholder
2935 # that will cause the pool to recreate the connection if
2936 # its needed.
2937 connection.disconnect()
2938 self.pool.put_nowait(None)
2939 return
2940 if connection.should_reconnect():
2941 connection.disconnect()
2942 # Put the connection back into the pool.
2943 try:
2944 self.pool.put_nowait(connection)
2945 except Full:
2946 # perhaps the pool has been reset() after a fork? regardless,
2947 # we don't want this connection
2948 pass
2949 finally:
2950 if self._locked:
2951 try:
2952 self._lock.release()
2953 except Exception:
2954 pass
2955 self._locked = False
2957 def disconnect(self, inuse_connections: bool = True):
2958 "Disconnects either all connections in the pool or just the free connections."
2959 self._checkpid()
2960 try:
2961 if self._in_maintenance:
2962 self._lock.acquire()
2963 self._locked = True
2964 if inuse_connections:
2965 connections = self._connections
2966 else:
2967 connections = self._get_free_connections()
2968 for connection in connections:
2969 connection.disconnect()
2970 finally:
2971 if self._locked:
2972 try:
2973 self._lock.release()
2974 except Exception:
2975 pass
2976 self._locked = False
2978 def _get_free_connections(self):
2979 with self._lock:
2980 return {conn for conn in self.pool.queue if conn}
2982 def _get_in_use_connections(self):
2983 with self._lock:
2984 # free connections
2985 connections_in_queue = {conn for conn in self.pool.queue if conn}
2986 # in self._connections we keep all created connections
2987 # so the ones that are not in the queue are the in use ones
2988 return {
2989 conn for conn in self._connections if conn not in connections_in_queue
2990 }
2992 def set_in_maintenance(self, in_maintenance: bool):
2993 """
2994 Sets a flag that this Blocking ConnectionPool is in maintenance mode.
2996 This is used to prevent new connections from being created while we are in maintenance mode.
2997 The pool will be in maintenance mode only when we are processing a MOVING notification.
2998 """
2999 self._in_maintenance = in_maintenance