Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32 CacheProxy,
33)
35from ._defaults import (
36 DEFAULT_SOCKET_CONNECT_TIMEOUT,
37 DEFAULT_SOCKET_READ_SIZE,
38 DEFAULT_SOCKET_TIMEOUT,
39 get_default_socket_keepalive_options,
40)
41from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
42from .auth.token import TokenInterface
43from .backoff import NoBackoff
44from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
45from .driver_info import DriverInfo, resolve_driver_info
46from .event import AfterConnectionReleasedEvent, EventDispatcher
47from .exceptions import (
48 AuthenticationError,
49 AuthenticationWrongNumberOfArgsError,
50 ChildDeadlockedError,
51 ConnectionError,
52 DataError,
53 MaxConnectionsError,
54 RedisError,
55 ResponseError,
56 TimeoutError,
57)
58from .maint_notifications import (
59 MaintenanceState,
60 MaintNotificationsConfig,
61 MaintNotificationsConnectionHandler,
62 MaintNotificationsPoolHandler,
63 OSSMaintNotificationsHandler,
64)
65from .observability.attributes import (
66 DB_CLIENT_CONNECTION_POOL_NAME,
67 DB_CLIENT_CONNECTION_STATE,
68 AttributeBuilder,
69 ConnectionState,
70 CSCReason,
71 CSCResult,
72 get_pool_name,
73)
74from .observability.metrics import CloseReason
75from .observability.recorder import (
76 init_csc_items,
77 record_connection_closed,
78 record_connection_count,
79 record_connection_create_time,
80 record_connection_wait_time,
81 record_csc_eviction,
82 record_csc_network_saved,
83 record_csc_request,
84 record_error_count,
85 register_csc_items_callback,
86)
87from .retry import Retry
88from .utils import (
89 CRYPTOGRAPHY_AVAILABLE,
90 DEFAULT_RESP_VERSION,
91 HIREDIS_AVAILABLE,
92 SENTINEL,
93 SSL_AVAILABLE,
94 check_protocol_version,
95 compare_versions,
96 deprecated_args,
97 ensure_string,
98 format_error_message,
99 str_if_bytes,
100)
102if SSL_AVAILABLE:
103 import ssl
104 from ssl import VerifyFlags
105else:
106 ssl = None
107 VerifyFlags = None
109if HIREDIS_AVAILABLE:
110 import hiredis
# RESP wire-format tokens used when serialising commands.
SYM_STAR = b"*"  # prefix of an array header ("*<count>\r\n")
SYM_DOLLAR = b"$"  # prefix of a bulk-string length ("$<len>\r\n")
SYM_CRLF = b"\r\n"  # RESP line terminator
SYM_EMPTY = b""  # separator used with bytes.join

# Parser class used when the caller does not pass one explicitly: the
# hiredis-backed parser when the extension imported successfully, otherwise
# the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
class HiredisRespSerializer:
    """Serialise Redis commands into RESP bytes via the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A first argument containing spaces (e.g. ``"CONFIG GET"``) is split
        into separate arguments. ``bytearray`` and ``memoryview`` arguments
        are converted to ``bytes`` before being handed to
        ``hiredis.pack_command``.

        Returns:
            A one-element list holding the packed command.

        Raises:
            DataError: when ``hiredis.pack_command`` raises ``TypeError``
                (an argument of an unsupported type).
        """
        head, rest = args[0], args[1:]
        # Multi-word command names must travel as separate RESP arguments.
        if isinstance(head, str):
            args = tuple(head.encode().split()) + rest
        elif b" " in head:
            args = tuple(head.split()) + rest

        normalized = tuple(
            bytes(item) if isinstance(item, (bytearray, memoryview)) else item
            for item in args
        )
        try:
            packed = hiredis.pack_command(normalized)
        except TypeError as err:
            # Surface the failure as the client's DataError while keeping the
            # traceback of the original TypeError.
            raise DataError(err).with_traceback(err.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Pure-Python RESP command serialiser."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Byte size above which an argument (or the pending buffer) is emitted
        # as its own chunk rather than being concatenated into the buffer.
        self._buffer_cutoff = buffer_cutoff
        # Callable applied to every argument to obtain its wire bytes.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of bytes-like chunks which, written in order, form a
        single RESP array command. Large arguments and memoryviews are kept
        as separate chunks so they are not copied into a bigger buffer.
        """
        # The command name may embed literal arguments (e.g. 'CONFIG GET').
        # The server expects them as separate arguments, so split the first
        # one. The pieces are bytes so that ``encode`` leaves them untouched.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        cutoff = self._buffer_cutoff
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for encoded in map(self.encode, args):
            size = len(encoded)
            header = (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF)
            send_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if not send_separately:
                # Small argument: append it inline to the running buffer.
                pending = SYM_EMPTY.join(header + (encoded, SYM_CRLF))
                continue
            # Large argument (or memoryview): flush the buffer together with the
            # length header, then emit the argument as its own chunk to avoid
            # a large concatenation.
            chunks.append(SYM_EMPTY.join(header))
            chunks.append(encoded)
            pending = SYM_CRLF
        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """Contract that Redis connection implementations provide.

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the contract but are not enforced
    when a subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs used to build ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register *callback* to run whenever the connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously registered for connect events."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a new parser instance created from *parser_class*."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by this connection."""

    @abstractmethod
    def connect(self):
        """Establish the connection to the server."""

    @abstractmethod
    def on_connect(self):
        """Hook run once the socket has been connected."""

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Close the connection."""

    @abstractmethod
    def check_health(self):
        """Verify that the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already-packed command to the server."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""

    @abstractmethod
    def can_read(self, timeout: float = 0) -> bool:
        """Report whether the connection has something pending to read."""
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the next reply from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialise a single command for the wire."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialise several commands for the wire."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata collected during the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used by a later re-authentication."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """

    @abstractmethod
    def extract_connection_details(self) -> str:
        """Return a textual description of the connection."""

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """
        Return ``True`` if the connection to the server is active.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.

    This class is expected to be used as a base class together with
    ConnectionInterface, so it is intended for multiple inheritance. The
    concrete connection supplies the parser, socket, host and timeout
    accessors declared abstract below.

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration
                for maintenance notifications. When it is falsy or not enabled,
                all notification handlers are set to None.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]):
                The pool handler for maintenance notifications. A per-connection
                copy of it is attached to this connection.
            maintenance_state (MaintenanceState): The current maintenance state of
                the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance
                notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the
                connection (defaults to ``self.host``).
            orig_socket_timeout (Optional[float]): The original socket timeout of
                the connection (defaults to ``self.socket_timeout``).
            orig_socket_connect_timeout (Optional[float]): The original socket
                connect timeout (defaults to ``self.socket_connect_timeout``).
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]):
                The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the
                push-notification handlers are attached to. Required when
                maintenance notifications are enabled.
            event_dispatcher (Optional[EventDispatcher]): Stored on
                ``self.event_dispatcher``. A new ``EventDispatcher`` is created
                when omitted.

        Raises:
            RedisError: if notifications are enabled and *parser* is missing or
                is neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        self.event_dispatcher = (
            event_dispatcher if event_dispatcher is not None else EventDispatcher()
        )

        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # Ids of maintenance notifications already seen on this connection.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Optional[Union[int, str]]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        When notifications are disabled, only the three handler attributes are
        set (to None). The ``orig_*`` attributes are NOT assigned in that case.

        Raises:
            RedisError: if *parser* is missing or is not a hiredis/RESP3 parser.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, (_HiredisParser, _RESP3Parser)):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference.
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters (falsy overrides fall back to
        # the connection's current values).
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of *maint_notifications_pool_handler*.

        Also creates the connection handler (and registers it on the parser)
        if none exists yet, or otherwise updates its config to the pool
        handler's config.
        """
        # Copy the pool handler to avoid sharing the same pool handler
        # between multiple connections. Otherwise each connection would override
        # the connection reference and the pool handler would only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """Register the OSS cluster handler on this connection's parser.

        Like the pool-handler setter, this creates the connection handler if it
        is missing, or otherwise updates its config.
        """
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Send the maintenance-notifications handshake when applicable.

        The handshake is sent only if all of the following hold:
        - the protocol is not RESP2;
        - notifications are enabled in the config;
        - a connection handler exists;
        - the connection has a host to determine the endpoint type from.

        With ``enabled == "auto"`` an unsupported-server response is only
        logged. Otherwise the error is raised.
        """
        host = getattr(self, "host", None)
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and host is not None
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send ``CLIENT MAINT_NOTIFICATIONS ON`` and validate the reply.

        Raises:
            ValueError: if the connection has no ``host``.
            ResponseError: if the server does not reply ``OK``, unless
                ``maint_notifications_config.enabled == "auto"``, in which case
                the failure is only logged at debug level.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except ResponseError as e:
            # Any other exception propagates unchanged. A ResponseError is only
            # tolerated in "auto" mode.
            if maint_notifications_config.enabled != "auto":
                raise
            # Log but don't fail the connection.
            import logging

            logging.getLogger(__name__).debug(
                "Failed to enable maintenance notifications: %s", e
            )

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket's peer address (most
        accurate), then falls back to DNS resolution of ``self.host`` /
        ``self.port`` (defaulting to ``"localhost"`` / ``6379`` when those
        attributes are absent).

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """
        # Method 1: Try to get the actual IP from the established socket connection.
        # This is most accurate as it shows the exact IP being used.
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # Only tuple addresses (e.g. (host, port) for TCP) carry an IP in
                # slot 0. Other address families (such as AF_UNIX) return a path
                # string/bytes, whose first character is not an address, so
                # fall through to DNS resolution instead.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host.
        # This is less accurate but works when the socket is not available.
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP.
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # and sockaddr[0] is the IP address.
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError):
            # DNS resolution might fail (socket.gaierror is an OSError subclass).
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        """Remember that the start notification *id* has been processed."""
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        """Return the live set of processed start-notification ids."""
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        """Remember that the end notification *id* was skipped."""
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        """Return the live set of skipped end-notification ids."""
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        """Forget all processed-start and skipped-end notification ids."""
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Return the first element of the socket's peer address (the host for
        TCP sockets), or None when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply *relaxed_timeout* (or, for -1, ``self.socket_timeout``).

        The timeout is applied to the open socket and to the parser.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # If the current timeout is 0, we are in the middle of a can_read
            # call. In that case we don't change the socket timeout, because the
            # operation is non-blocking and should return immediately. Changing
            # the state from non-blocking to blocking in the middle of a read
            # operation would lead to a deadlock.
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            # The parser is updated even mid-can_read so that it holds the new
            # value.
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Mirror *timeout* onto the parser so it matches the socket timeout.

        ``None`` (a blocking socket) is a legitimate value and is propagated
        to both RESP3 and hiredis parsers. Previously RESP3 ignored it, which
        left a stale relaxed timeout in place after a reset to ``None``.
        """
        parser = self._get_parser()
        if not parser or not parser._buffer:
            return
        if isinstance(parser, _RESP3Parser):
            # NOTE(review): presumably the RESP3 socket buffer restores this
            # value after custom-timeout reads such as can_read, so it has to
            # follow the socket, including None. 0 (non-blocking) is still
            # skipped, as before -- confirm against SocketBuffer.
            if timeout != 0:
                parser._buffer.socket_timeout = timeout
        elif isinstance(parser, _HiredisParser):
            parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override the host and/or timeouts.

        A *tmp_host_address* of SENTINEL (or any falsy value) leaves the host
        unchanged. A *tmp_relaxed_timeout* of -1 leaves both timeouts
        unchanged. Any other value, including the default ``None``, is written
        to ``socket_timeout`` and ``socket_connect_timeout``.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the host and/or timeouts saved in the ``orig_*`` attributes.

        Those attributes are assigned only when maintenance notifications were
        enabled at configuration time.
        """
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
784class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
785 "Manages communication to and from a Redis server"
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: Optional[float] = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Union[Optional[str], object] = SENTINEL,
        lib_version: Union[Optional[str], object] = SENTINEL,
        driver_info: Union[Optional[DriverInfo], object] = SENTINEL,
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = None,
        legacy_responses: bool = True,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Initialize a new Connection.

        No socket is opened here: ``self._sock`` starts as None.

        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version. Explicit None disables CLIENT SETINFO.
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        socket_connect_timeout : float, optional
            Falls back to ``socket_timeout`` when None.
        protocol : int, optional
            RESP version. None selects DEFAULT_RESP_VERSION. Other values must
            convert with ``int()`` to 2 or 3.

        Raises
        ------
        DataError
            If username/password are combined with a credential_provider.
        ConnectionError
            If ``protocol`` is not an integer, or not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version.
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self._socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self._socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                # Errors to retry on but no policy given: retry once, no backoff.
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            # No retry configured: zero retries.
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        # Must exist before _construct_command_packer, which may read self.encoder.
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        # Must be set before set_parser, which reads self._socket_read_size.
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        # Chunking threshold handed to PythonRespSerializer.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            # protocol is None (or not int-convertible by type): use the default.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self.legacy_responses = legacy_responses
        if self.protocol == 3 and parser_class == _RESP2Parser:
            # If the protocol is 3 but the parser is RESP2, change it to RESP3
            # This is needed because the parser might be set before the protocol
            # or might be provided as a kwarg to the constructor
            # We need to react on discrepancy only for RESP2 and RESP3
            # as hiredis supports both
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self._command_packer = self._construct_command_packer(command_packer)
        self._should_reconnect = False

        # Set up maintenance notifications. This must run last: it needs
        # self._parser (created by set_parser above) to attach push handlers.
        MaintNotificationsAbstractConnection.__init__(
            self,
            maint_notifications_config,
            maint_notifications_pool_handler,
            maintenance_state,
            maintenance_notification_hash,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            self._parser,
            event_dispatcher=self._event_dispatcher,
        )
943 def __repr__(self):
944 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
945 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
947 @abstractmethod
948 def repr_pieces(self):
949 pass
951 def __del__(self):
952 try:
953 self.disconnect()
954 except Exception:
955 pass
957 @property
958 def is_connected(self) -> bool:
959 return self._sock is not None
961 def _construct_command_packer(self, packer):
962 if packer is not None:
963 return packer
964 elif HIREDIS_AVAILABLE:
965 return HiredisRespSerializer()
966 else:
967 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
969 def register_connect_callback(self, callback):
970 """
971 Register a callback to be called when the connection is established either
972 initially or reconnected. This allows listeners to issue commands that
973 are ephemeral to the connection, for example pub/sub subscription or
974 key tracking. The callback must be a _method_ and will be kept as
975 a weak reference.
976 """
977 wm = weakref.WeakMethod(callback)
978 if wm not in self._connect_callbacks:
979 self._connect_callbacks.append(wm)
981 def deregister_connect_callback(self, callback):
982 """
983 De-register a previously registered callback. It will no-longer receive
984 notifications on connection events. Calling this is not required when the
985 listener goes away, since the callbacks are kept as weak methods.
986 """
987 try:
988 self._connect_callbacks.remove(weakref.WeakMethod(callback))
989 except ValueError:
990 pass
992 def set_parser(self, parser_class):
993 """
994 Creates a new instance of parser_class with socket size:
995 _socket_read_size and assigns it to the parser for the connection
996 :param parser_class: The required parser class
997 """
998 self._parser = parser_class(socket_read_size=self._socket_read_size)
1000 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
1001 return self._parser
1003 def connect(self):
1004 "Connects to the Redis server if not already connected"
1005 # try once the socket connect with the handshake, retry the whole
1006 # connect/handshake flow based on retry policy
1007 self.retry.call_with_retry(
1008 lambda: self.connect_check_health(
1009 check_health=True, retry_socket_connect=False
1010 ),
1011 lambda error: self.disconnect(error),
1012 )
1014 def connect_check_health(
1015 self, check_health: bool = True, retry_socket_connect: bool = True
1016 ):
1017 if self._sock:
1018 return
1019 # Track actual retry attempts for error reporting
1020 actual_retry_attempts = [0]
1022 def failure_callback(error, failure_count):
1023 actual_retry_attempts[0] = failure_count
1024 self.disconnect(error=error, failure_count=failure_count)
1026 try:
1027 if retry_socket_connect:
1028 sock = self.retry.call_with_retry(
1029 self._connect,
1030 failure_callback,
1031 with_failure_count=True,
1032 )
1033 else:
1034 sock = self._connect()
1035 except socket.timeout:
1036 e = TimeoutError("Timeout connecting to server")
1037 record_error_count(
1038 server_address=self.host,
1039 server_port=self.port,
1040 network_peer_address=self.host,
1041 network_peer_port=self.port,
1042 error_type=e,
1043 retry_attempts=actual_retry_attempts[0],
1044 )
1045 raise e
1046 except OSError as e:
1047 e = ConnectionError(self._error_message(e))
1048 record_error_count(
1049 server_address=getattr(self, "host", None),
1050 server_port=getattr(self, "port", None),
1051 network_peer_address=getattr(self, "host", None),
1052 network_peer_port=getattr(self, "port", None),
1053 error_type=e,
1054 retry_attempts=actual_retry_attempts[0],
1055 )
1056 raise e
1058 self._sock = sock
1059 try:
1060 if self.redis_connect_func is None:
1061 # Use the default on_connect function
1062 self.on_connect_check_health(check_health=check_health)
1063 else:
1064 # Use the passed function redis_connect_func
1065 self.redis_connect_func(self)
1066 except RedisError:
1067 # clean up after any error in on_connect
1068 self.disconnect()
1069 raise
1071 # run any user callbacks. right now the only internal callback
1072 # is for pubsub channel/pattern resubscription
1073 # first, remove any dead weakrefs
1074 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
1075 for ref in self._connect_callbacks:
1076 callback = ref()
1077 if callback:
1078 callback(self)
1080 @abstractmethod
1081 def _connect(self):
1082 pass
1084 @abstractmethod
1085 def _host_error(self):
1086 pass
1088 def _error_message(self, exception):
1089 return format_error_message(self._host_error(), exception)
1091 def on_connect(self):
1092 self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection, authenticate and select a database.

        Handshake order: authentication (``HELLO <proto> AUTH ...`` when a
        non-RESP2 protocol is configured, otherwise ``AUTH``), or a plain
        ``HELLO`` protocol switch when there are no credentials; then
        maintenance-notification activation, ``CLIENT SETNAME``, best-effort
        ``CLIENT SETINFO`` (LIB-NAME / LIB-VER), and ``SELECT``.

        :param check_health: forwarded to the handshake commands, except the
            credential-carrying ``HELLO ... AUTH`` / ``AUTH`` commands, which are
            always sent with ``check_health=False``.
        :raises AuthenticationError: if ``AUTH`` does not reply ``OK``.
        :raises ConnectionError: if a credential-less ``HELLO`` reply reports a
            different protocol, ``CLIENT SETNAME`` does not reply ``OK``, or
            ``SELECT`` does not reply ``OK``.
        """
        self._parser.on_connect(self)
        # Keep the original parser so its EXCEPTION_CLASSES can be copied onto
        # a replacement RESP3 parser below.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            # A password-only credential is sent as the "default" user, since
            # HELLO ... AUTH takes a username and a password.
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # NOTE(review): unlike the credential-less HELLO branch below, the
            # negotiated protocol in this reply is not validated; a disabled
            # "Invalid RESP version" check was left here as commented-out code:
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The reply map may use bytes or str keys depending on decoding,
            # so both spellings of "proto" are checked.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info.
        # Best effort: a ResponseError (e.g. a server that does not support
        # CLIENT SETINFO) is ignored.
        try:
            if self.driver_info and self.driver_info.formatted_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.driver_info.formatted_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.driver_info and self.driver_info.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.driver_info.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
    def disconnect(self, *args, **kwargs):
        """Disconnect from the Redis server.

        Notifies the parser, detaches the socket and clears the reconnect flag.
        If a socket was attached, it is shut down (only in the process whose
        pid matches ``self.pid``) and closed. The close is then recorded, and
        any maintenance-state overrides are reset. Nothing after the socket
        check runs when no socket was attached.

        Positional ``args`` are accepted but ignored. Recognized keyword args:
            error: the exception that caused the disconnect. When truthy, the
                close is recorded as ERROR (or HEALTHCHECK_FAILED) instead of
                APPLICATION_CLOSE.
            failure_count: failed attempts so far. When it exceeds
                ``self.retry.get_retries()``, the error is also recorded via
                ``record_error_count``.
            health_check_failed: marks an error close as a failed health check.
        """
        self._parser.on_disconnect()

        # Detach first so the connection reads as disconnected even if the
        # socket shutdown/close below fails.
        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self.reset_should_reconnect()

        if conn_sock is None:
            return

        # Only shut down the socket in the process that created the connection
        # (NOTE(review): presumably so a forked child does not shut down a socket
        # still used by its parent — confirm).
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

        error = kwargs.get("error")
        failure_count = kwargs.get("failure_count")
        health_check_failed = kwargs.get("health_check_failed")

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # Record the error only once the retry budget has been exceeded.
            if failure_count is not None and failure_count > self.retry.get_retries():
                record_error_count(
                    server_address=self.host,
                    server_port=self.port,
                    network_peer_address=self.host,
                    network_peer_port=self.port,
                    error_type=error,
                    retry_attempts=failure_count,
                )

            record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

        if self.maintenance_state == MaintenanceState.MAINTENANCE:
            # this block will be executed only if the connection was in maintenance state
            # and the connection was closed.
            # The state change won't be applied on connections that are in Moving state
            # because their state and configurations will be handled when the moving ttl expires.
            self.reset_tmp_settings(reset_relaxed_timeout=True)
            self.maintenance_state = MaintenanceState.NONE
            # reset the sets that keep track of received start maint
            # notifications and skipped end maint notifications
            self.reset_received_notifications()
1273 def mark_for_reconnect(self):
1274 self._should_reconnect = True
1276 def should_reconnect(self):
1277 return self._should_reconnect
1279 def reset_should_reconnect(self):
1280 self._should_reconnect = False
1282 def _send_ping(self):
1283 """Send PING, expect PONG in return"""
1284 self.send_command("PING", check_health=False)
1285 if str_if_bytes(self.read_response()) != "PONG":
1286 raise ConnectionError("Bad response from PING health check")
1288 def _ping_failed(self, error, failure_count):
1289 """Function to call when PING fails"""
1290 self.disconnect(
1291 error=error, failure_count=failure_count, health_check_failed=True
1292 )
1294 def check_health(self):
1295 """Check the health of the connection with a PING/PONG"""
1296 if self.health_check_interval and time.monotonic() > self.next_health_check:
1297 self.retry.call_with_retry(
1298 self._send_ping,
1299 self._ping_failed,
1300 with_failure_count=True,
1301 )
1303 def send_packed_command(self, command, check_health=True):
1304 """Send an already packed command to the Redis server"""
1305 if not self._sock:
1306 self.connect_check_health(check_health=False)
1307 # guard against health check recursion
1308 if check_health:
1309 self.check_health()
1310 try:
1311 if isinstance(command, str):
1312 command = [command]
1313 for item in command:
1314 self._sock.sendall(item)
1315 except socket.timeout:
1316 self.disconnect()
1317 raise TimeoutError("Timeout writing to socket")
1318 except OSError as e:
1319 self.disconnect()
1320 if len(e.args) == 1:
1321 errno, errmsg = "UNKNOWN", e.args[0]
1322 else:
1323 errno = e.args[0]
1324 errmsg = e.args[1]
1325 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1326 except BaseException:
1327 # BaseExceptions can be raised when a socket send operation is not
1328 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1329 # to send un-sent data. However, the send_packed_command() API
1330 # does not support it so there is no point in keeping the connection open.
1331 self.disconnect()
1332 raise
1334 def send_command(self, *args, **kwargs):
1335 """Pack and send a command to the Redis server"""
1336 self.send_packed_command(
1337 self._command_packer.pack(*args),
1338 check_health=kwargs.get("check_health", True),
1339 )
1341 def can_read(self, timeout: float = 0) -> bool:
1342 """Poll the socket to see if there's data that can be read."""
1343 # TODO: Rename this API; it detects pending data or dirty/closed
1344 # connection state, not only whether application data can be read.
1345 sock = self._sock
1346 if not sock:
1347 self.connect()
1349 host_error = self._host_error()
1351 try:
1352 return self._parser.can_read(timeout)
1354 except OSError as e:
1355 self.disconnect()
1356 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
1358 def read_response(
1359 self,
1360 disable_decoding=False,
1361 *,
1362 timeout: Union[float, object] = SENTINEL,
1363 disconnect_on_error=True,
1364 push_request=False,
1365 ):
1366 """Read the response from a previously sent command"""
1368 host_error = self._host_error()
1370 try:
1371 if self.protocol in ["3", 3]:
1372 response = self._parser.read_response(
1373 disable_decoding=disable_decoding,
1374 push_request=push_request,
1375 timeout=timeout,
1376 )
1377 else:
1378 response = self._parser.read_response(
1379 disable_decoding=disable_decoding, timeout=timeout
1380 )
1381 except socket.timeout:
1382 if disconnect_on_error:
1383 self.disconnect()
1384 raise TimeoutError(f"Timeout reading from {host_error}")
1385 except OSError as e:
1386 if disconnect_on_error:
1387 self.disconnect()
1388 raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
1389 except BaseException:
1390 # Also by default close in case of BaseException. A lot of code
1391 # relies on this behaviour when doing Command/Response pairs.
1392 # See #1128.
1393 if disconnect_on_error:
1394 self.disconnect()
1395 raise
1397 if self.health_check_interval:
1398 self.next_health_check = time.monotonic() + self.health_check_interval
1400 if isinstance(response, ResponseError):
1401 try:
1402 raise response
1403 finally:
1404 del response # avoid creating ref cycles
1405 return response
1407 def pack_command(self, *args):
1408 """Pack a series of arguments into the Redis protocol"""
1409 return self._command_packer.pack(*args)
1411 def pack_commands(self, commands):
1412 """Pack multiple commands into the Redis protocol"""
1413 output = []
1414 pieces = []
1415 buffer_length = 0
1416 buffer_cutoff = self._buffer_cutoff
1418 for cmd in commands:
1419 for chunk in self._command_packer.pack(*cmd):
1420 chunklen = len(chunk)
1421 if (
1422 buffer_length > buffer_cutoff
1423 or chunklen > buffer_cutoff
1424 or isinstance(chunk, memoryview)
1425 ):
1426 if pieces:
1427 output.append(SYM_EMPTY.join(pieces))
1428 buffer_length = 0
1429 pieces = []
1431 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1432 output.append(chunk)
1433 else:
1434 pieces.append(chunk)
1435 buffer_length += chunklen
1437 if pieces:
1438 output.append(SYM_EMPTY.join(pieces))
1439 return output
1441 def get_protocol(self) -> Union[int, str]:
1442 return self.protocol
1444 @property
1445 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1446 return self._handshake_metadata
1448 @handshake_metadata.setter
1449 def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
1450 self._handshake_metadata = value
1452 def set_re_auth_token(self, token: TokenInterface):
1453 self._re_auth_token = token
1455 def re_auth(self):
1456 if self._re_auth_token is not None:
1457 self.send_command(
1458 "AUTH",
1459 self._re_auth_token.try_get("oid"),
1460 self._re_auth_token.get_value(),
1461 )
1462 self.read_response()
1463 self._re_auth_token = None
1465 def _get_socket(self) -> Optional[socket.socket]:
1466 return self._sock
1468 @property
1469 def socket_timeout(self) -> Optional[Union[float, int]]:
1470 return self._socket_timeout
1472 @socket_timeout.setter
1473 def socket_timeout(self, value: Optional[Union[float, int]]):
1474 self._socket_timeout = value
1476 @property
1477 def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1478 return self._socket_connect_timeout
1480 @socket_connect_timeout.setter
1481 def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1482 self._socket_connect_timeout = value
1484 def extract_connection_details(self) -> str:
1485 socket_address = None
1486 if self._sock is None:
1487 return "not connected"
1488 try:
1489 socket_address = self._sock.getsockname() if self._sock else None
1490 socket_address = socket_address[1] if socket_address else None
1491 except (AttributeError, OSError):
1492 pass
1494 return f"connected to ip {self.get_resolved_ip()}, local socket port: {socket_address}"
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=True,
        socket_keepalive_options=SENTINEL,
        socket_type=0,
        **kwargs,
    ):
        """
        Initialize a TCP connection.

        Parameters
        ----------
        socket_keepalive : bool
            If `True`, TCP keepalive is enabled for TCP socket connections.
        socket_keepalive_options : Mapping[int, int | bytes] | object | None
            Mapping of TCP keepalive socket option constants to values, for
            example `{socket.TCP_KEEPIDLE: 30}`. If left unspecified, redis-py
            uses TCP keepalive defaults when `socket_keepalive` is enabled:
            idle 30 seconds, interval 5 seconds, and 3 probes. Platform-specific
            options that are not available are skipped. Pass `None` or `{}` to
            avoid setting additional TCP keepalive options.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        if socket_keepalive_options is SENTINEL:
            socket_keepalive_options = get_default_socket_keepalive_options()
        self.socket_keepalive_options = socket_keepalive_options or {}
        # Passed to getaddrinfo() as the address-family filter (0 = any).
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Mimics ``socket.create_connection`` (IPv4/IPv6 support via getaddrinfo)
        but sets TCP_NODELAY, keepalive options and the connect timeout before
        calling ``connect()``. Each resolved address is tried in order. The
        last OSError is raised if none succeeds.

        A socket that fails during setup is always shut down and closed. For
        an OSError the next address is tried; any other exception (e.g. a
        TypeError from a bad keepalive option value, or KeyboardInterrupt) is
        re-raised after cleanup instead of leaking the socket.
        """
        err = None

        for family, socktype, proto, _canonname, socket_address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except BaseException as exc:
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()
                if not isinstance(exc, OSError):
                    raise
                # Remember the failure and try the next resolved address.
                err = exc

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection wrapper that adds client-side caching to another connection.

    Connection operations are delegated to the wrapped ``conn``. Replies to
    cacheable commands (``cache.is_cachable``) are stored in ``cache`` and
    served from it on later requests. ``CLIENT TRACKING ON`` is issued on
    every (re)connect, and the invalidation pushes it produces evict stale
    entries. :meth:`connect` rejects servers older than ``MIN_ALLOWED_VERSION``.

    The cache may hold entries created by other connections (see
    ``CacheEntry.connection_ref``), and those connections mutate it under
    their own locks. So every cache entry is looked up once and the local
    reference reused, instead of calling ``cache.get`` repeatedly.
    """

    # Placeholder stored while the reply for a cache key is still in flight.
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.db = self._conn.db
        self._event_dispatcher = self._conn._event_dispatcher
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command most recently sent through send_command();
        # None when that command is not cacheable.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._oss_cluster_maint_notifications_handler,
                self._conn._get_parser(),
                event_dispatcher=self._conn.event_dispatcher,
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    @property
    def is_connected(self) -> bool:
        return self._conn.is_connected

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_cluster_handler_for_connection(
                oss_cluster_maint_notifications_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching.

        :raises ConnectionError: if the HELLO metadata lacks server name/version,
            or the server is not ``DEFAULT_SERVER_NAME`` at
            ``MIN_ALLOWED_VERSION`` or later.
        """
        self._conn.connect()

        # HELLO metadata keys may be bytes or str depending on decoding.
        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args, **kwargs):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args, **kwargs)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send a pre-packed command through the wrapped connection (not cached).

        ``check_health`` is forwarded; it was previously dropped, so callers
        could not suppress the health check.
        """
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or prepare to serve it from the cache.

        Non-cacheable commands are sent directly. For cacheable ones the
        ``keys`` keyword is required to build the cache key. When a valid
        entry already exists, nothing is sent and ``read_response`` returns
        the cached value. Otherwise an IN_PROGRESS placeholder is stored and
        the command is sent.

        :raises ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            # Look the entry up once: another connection may evict it between
            # two separate lookups.
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                with self._pool_lock:
                    while entry.connection_ref.can_read():
                        try:
                            entry.connection_ref.read_response(
                                push_request=True,
                                timeout=0,
                                disconnect_on_error=False,
                            )
                        except TimeoutError:
                            break

                # Re-check: if the entry was invalidated during the drain,
                # fall through to send the command over the network.
                if self._cache.get(self._current_command_cache_key):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout: float = 0) -> bool:
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        return self._conn.can_read(timeout)

    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Return the reply for the last command, from the cache when possible.

        A valid (not IN_PROGRESS) cache entry is returned as a deep copy and
        counted as a cache hit. Otherwise the reply is read from the wrapped
        connection and stored in the cache entry, provided the entry still
        exists (it may have been invalidated meanwhile). None replies are not
        cached.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                # Single lookup: the entry may be evicted concurrently.
                cached = self._cache.get(self._current_command_cache_key)
                if (
                    cached is not None
                    and cached.status != CacheEntryStatus.IN_PROGRESS
                ):
                    # Deep copy so callers cannot mutate the cached value.
                    res = copy.deepcopy(cached.cache_value)
                    self._current_command_cache_key = None
                    record_csc_request(
                        result=CSCResult.HIT,
                    )
                    record_csc_network_saved(
                        bytes_saved=len(res) if hasattr(res, "__len__") else 0,
                    )
                    return res
                record_csc_request(
                    result=CSCResult.MISS,
                )

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            timeout=timeout,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    @property
    def _maint_notifications_connection_handler(
        self,
    ) -> Optional[MaintNotificationsConnectionHandler]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._maint_notifications_connection_handler
        return None

    @_maint_notifications_connection_handler.setter
    def _maint_notifications_connection_handler(
        self, value: Optional[MaintNotificationsConnectionHandler]
    ):
        self._conn._maint_notifications_connection_handler = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        # Return the wrapped connection's socket; it was previously discarded.
        return self._conn._connect()

    def _host_error(self):
        # Return the wrapped connection's endpoint description; it was
        # previously discarded, so error messages showed "None".
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable CLIENT TRACKING and route invalidation pushes here."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain pending push messages (e.g. invalidations) without blocking."""
        while self.can_read():
            try:
                self._conn.read_response(
                    push_request=True, timeout=0, disconnect_on_error=False
                )
            except TimeoutError:
                break

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Handle an invalidation push: ``data[1]`` is the key list, or None for a full flush."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                keys_deleted = self._cache.delete_by_redis_keys(data[1])

                if len(keys_deleted) > 0:
                    record_csc_eviction(
                        count=len(keys_deleted),
                        reason=CSCReason.INVALIDATION,
                    )

    def extract_connection_details(self) -> str:
        return self._conn.extract_connection_details()
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
                Forced to False when ssl_cert_reqs resolves to ssl.CERT_NONE.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError: if Python was built without SSL support, or if
                ssl_cert_reqs is a string other than "none", "optional" or
                "required".
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.

        The plain socket created by the parent class is closed when wrapping
        raises ``OSError`` or ``RedisError``, and the exception is re-raised.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Builds an SSLContext from this connection's settings and performs the
        TLS handshake. When requested, it then validates the server through
        OCSP: either a stapled response or a direct OCSP query. If that
        validation fails, the wrapped socket is closed before the error
        propagates.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if pure OCSP validation is requested without the
                ``cryptography`` package, or both OCSP modes are enabled.
            ConnectionError: if pure OCSP validation reports an invalid
                certificate.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                self._verify_ocsp_response(sslsock)
        except BaseException:
            # wrap_socket hands the descriptor over to the returned SSLSocket
            # (CPython detaches the original socket object). Closing ``sock``
            # in _connect would therefore not release it, so close the TLS
            # socket here.
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Validate the server's stapled OCSP response over a second TLS connection.

        Uses ``self.ssl_ocsp_context`` when provided. Otherwise it builds a
        basic context from ``self.certfile`` and ``self.keyfile``. Errors from
        OpenSSL or the staple verifier propagate to the caller.
        """
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; it only exists for this check, so always close
        # it, whether or not the handshake succeeds.
        staple_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            staple_sock.close()

    def _verify_ocsp_response(self, sslsock):
        """Run a direct OCSP check for the peer certificate of ``sslsock``.

        Raises:
            ConnectionError: if the OCSP verifier reports the certificate as
                invalid.
        """
        from .ocsp import OCSPVerifier

        verifier = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
        if not verifier.is_valid():
            raise ConnectionError("ocsp validation error")
class UnixDomainSocketConnection(AbstractConnection):
    """Connection that talks to a Redis server over a Unix domain socket."""

    def __init__(self, path="", socket_timeout=DEFAULT_SOCKET_TIMEOUT, **kwargs):
        # The base initializer runs first; socket_timeout is then set from
        # this constructor's own argument.
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the (name, value) pairs shown in this connection's repr.

        ``client_name`` is appended only when it is truthy.
        """
        pieces = [("path", self.path), ("db", self.db)]
        client_name = self.client_name
        return (pieces + [("client_name", client_name)]) if client_name else pieces

    def _connect(self):
        """Open a stream socket connected to ``self.path``.

        ``socket_connect_timeout`` applies while connecting and
        ``socket_timeout`` afterwards. On a failed connect, the socket is shut
        down (best effort) and closed before the ``OSError`` is re-raised.
        """
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        else:
            uds.settimeout(self.socket_timeout)
            return uds

    def _host_error(self):
        """Describe the endpoint for error messages: the socket path."""
        socket_path = self.path
        return socket_path
# Upper-cased spellings that to_bool() treats as boolean False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Interpret a URL query-string value as a boolean.

    - ``None`` and ``""`` return ``None`` (meaning "not specified").
    - Strings in ``FALSE_STRINGS`` (case-insensitive) return ``False``.
    - Any other string returns ``True``.
    - Non-string values return ``bool(value)``.
    """
    if value is None or value == "":
        return None
    if not isinstance(value, str):
        return bool(value)
    return value.upper() not in FALSE_STRINGS
def parse_ssl_verify_flags(value):
    """Parse a URL query value into a list of ``VerifyFlags`` members.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``. The square
    brackets are optional, and empty items (from ``"[]"`` or a trailing
    comma) are ignored.

    Only real enum member names are accepted. A plain ``hasattr`` check
    would also accept class attributes such as ``mro`` or ``__class__``.
    Their non-flag values would only fail later, when OR-ed into
    ``SSLContext.verify_flags``.

    Raises:
        ValueError: if an item is not a ``VerifyFlags`` member name, including
            when ``VerifyFlags`` is unavailable.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # VerifyFlags may be a placeholder without members when SSL is unavailable.
    members = getattr(VerifyFlags, "__members__", None) or {}

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        member = members.get(flag)
        if member is None:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(member)
    return verify_flags
# Converters for recognised URL query-string options. parse_url() looks up each
# option name here and applies the converter to the (unquoted) string value.
# A TypeError/ValueError from the converter is re-raised as ValueError naming
# the option. Options without an entry are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_read_size": int,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    # NOTE(review): ``list`` applied to the query string splits it into single
    # characters (e.g. "abc" -> ["a", "b", "c"]). It does not produce a list
    # of exception types. Confirm whether this option is usable from a URL.
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "ssl_min_version": int,
    "timeout": float,
    "protocol": int,
    "legacy_responses": to_bool,
}
def parse_url(url):
    """Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into pool kwargs.

    - Query-string options are converted via ``URL_QUERY_ARGUMENT_PARSERS``
      when a converter exists; otherwise they are kept as strings.
    - Username and password are percent-decoded.
    - ``unix://`` URLs yield ``path`` plus ``UnixDomainSocketConnection``.
    - TCP URLs yield ``host``/``port``. If no ``db`` query option was given,
      ``db`` is taken from a numeric URL path when possible.
    - ``rediss://`` additionally selects ``SSLConnection``.

    Raises:
        ValueError: for an unsupported scheme or an option value its
            converter rejects.
    """
    supported_prefixes = ("redis://", "rediss://", "unix://")
    if not url.startswith(supported_prefixes):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, raw_values in parse_qs(parsed.query).items():
        if not raw_values:
            continue
        # NOTE(review): parse_qs has already percent-decoded the value, so this
        # unquote decodes a second time (a literal "%41" becomes "A").
        value = unquote(raw_values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = value
            continue
        try:
            kwargs[name] = parser(value)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # Remaining schemes: redis:// and rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A db given in the query string takes precedence over the URL path.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection
    return kwargs
# TypeVar bound to ConnectionPool so that ``from_url`` (annotated
# ``cls: Type[_CP]`` -> ``_CP``) is typed as returning the subclass it is
# called on.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface that connection pool implementations provide."""

    @abstractmethod
    def get_protocol(self):
        """Return the protocol used by this pool."""

    @abstractmethod
    def reset(self):
        """Reset the pool."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Obtain a connection from the pool.

        Passing any arguments is deprecated since 5.3.0; call
        ``get_connection()`` without arguments.
        """

    @abstractmethod
    def get_encoder(self):
        """Return the encoder for this pool."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Release ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect the pool's connections (in-use ones too when requested)."""

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy to ``retry``."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate using ``token``."""

    @abstractmethod
    def get_connection_count(self) -> list[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        """Set up maintenance-notification handling for the pool.

        Args:
            maint_notifications_config: Explicit configuration. When ``None``
                and ``kwargs["protocol"]`` is RESP3, a default
                ``MaintNotificationsConfig()`` is used.
            oss_cluster_maint_notifications_handler: When given and
                notifications are enabled, this handler is used instead of a
                per-pool ``MaintNotificationsPoolHandler``.
            **kwargs: Connection keyword arguments. Only ``protocol`` and
                ``event_dispatcher`` are read here.

        Raises:
            RedisError: if notifications are enabled but the protocol is not
                RESP3.
        """
        # Initialize maintenance notifications
        is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3)

        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._event_dispatcher = kwargs.get("event_dispatcher", None)
            if self._event_dispatcher is None:
                self._event_dispatcher = EventDispatcher()

            # Exactly one handler is active. A pool handler is only
            # constructed when no cluster-level OSS handler was supplied.
            if oss_cluster_maint_notifications_handler:
                self._oss_cluster_maint_notifications_handler = (
                    oss_cluster_maint_notifications_handler
                )
                self._maint_notifications_pool_handler = None
                self._update_connection_kwargs_for_maint_notifications(
                    oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
                )
            else:
                self._oss_cluster_maint_notifications_handler = None
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )

                self._update_connection_kwargs_for_maint_notifications(
                    maint_notifications_pool_handler=self._maint_notifications_pool_handler
                )
        else:
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            A truthy value if maintenance notifications are enabled. Otherwise
            a falsy one (``None`` when no config is available).
            The config is taken from the OSS cluster handler when one is set,
            otherwise from the pool handler. Without either handler,
            notifications are not enabled.
        """
        if self._oss_cluster_maint_notifications_handler:
            maint_notifications_config = (
                self._oss_cluster_maint_notifications_handler.config
            )
        else:
            maint_notifications_config = (
                self._maint_notifications_pool_handler.config
                if self._maint_notifications_pool_handler
                else None
            )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Existing free connections are disconnected and in-use connections are
        marked for reconnect. Both steps happen in
        ``_update_maint_notifications_configs_for_connections``.

        Raises:
            ValueError: when trying to disable notifications that are enabled.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            # first update pool settings
            if not self._maint_notifications_pool_handler:
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )
            else:
                self._maint_notifications_pool_handler.config = (
                    maint_notifications_config
                )

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.

        Does nothing unless notifications are enabled. When both handlers are
        given, the OSS handler's config wins for ``maint_notifications_config``
        because it is applied last. The original host and timeouts are
        recorded once, the first time this runs.
        """
        if not self.maint_notifications_enabled():
            return
        if maint_notifications_pool_handler:
            self.connection_kwargs.update(
                {
                    "maint_notifications_pool_handler": maint_notifications_pool_handler,
                    "maint_notifications_config": maint_notifications_pool_handler.config,
                }
            )
        if oss_cluster_maint_notifications_handler:
            self.connection_kwargs.update(
                {
                    "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
                    "maint_notifications_config": oss_cluster_maint_notifications_handler.config,
                }
            )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", DEFAULT_SOCKET_TIMEOUT
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", DEFAULT_SOCKET_CONNECT_TIMEOUT
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """Update the maintenance notifications config for all connections in the pool.

        Runs under the pool lock. Free connections get the handler and config
        and are then disconnected. In-use connections get the handler/config
        and are marked for reconnect. The OSS handler takes precedence over
        the pool handler.

        Raises:
            ValueError: if neither handler is provided and the pool has
                connections.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if oss_cluster_maint_notifications_handler:
                    # set cluster handler for conn
                    conn.set_maint_notifications_cluster_handler_for_connection(
                        oss_cluster_maint_notifications_handler
                    )
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                if oss_cluster_maint_notifications_handler:
                    # NOTE(review): in-use connections are configured through
                    # _configure_maintenance_notifications, whereas free ones
                    # use set_maint_notifications_cluster_handler_for_connection.
                    # Confirm the asymmetry is intended.
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                    conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        An empty/``None`` matching value matches every connection:

        - ``connected_address`` compares ``conn.getpeername()``.
        - ``configured_address`` compares ``conn.host``.
        - ``notification_hash`` compares ``conn.maintenance_notification_hash``.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        ``update_current_socket_timeout`` is always called last with
        ``relaxed_timeout``, including when it is ``None``.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.
        In-use connections are processed first, then (optionally) free ones,
        all under the pool lock.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.
        Only in-use connections whose peer address matches are affected; a
        ``None`` address matches all of them.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.
        Only free connections whose peer address matches are affected; a
        ``None`` address matches all of them.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
2804class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
2805 """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
2815 If ``maint_notifications_config`` is provided, the connection pool will support
2816 maintenance notifications.
2817 Maintenance notifications are supported only with RESP3.
2818 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2819 the maintenance notifications will be enabled by default.
2821 Any additional keyword arguments are passed to the constructor of
2822 ``connection_class``.
2823 """
2825 @classmethod
2826 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
2827 """
2828 Return a connection pool configured from the given URL.
2830 For example::
2832 redis://[[username]:[password]]@localhost:6379/0
2833 rediss://[[username]:[password]]@localhost:6379/0
2834 unix://[username@]/path/to/socket.sock?db=0[&password=password]
2836 Three URL schemes are supported:
2838 - `redis://` creates a TCP socket connection. See more at:
2839 <https://www.iana.org/assignments/uri-schemes/prov/redis>
2840 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
2841 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
2842 - ``unix://``: creates a Unix Domain Socket connection.
2844 The username, password, hostname, path and all querystring values
2845 are passed through urllib.parse.unquote in order to replace any
2846 percent-encoded values with their corresponding characters.
2848 There are several ways to specify a database number. The first value
2849 found will be used:
2851 1. A ``db`` querystring option, e.g. redis://localhost?db=0
2852 2. If using the redis:// or rediss:// schemes, the path argument
2853 of the url, e.g. redis://localhost/0
2854 3. A ``db`` keyword argument to this function.
2856 If none of these options are specified, the default db=0 is used.
2858 All querystring options are cast to their appropriate Python types.
2859 Boolean arguments can be specified with string values "True"/"False"
2860 or "Yes"/"No". Values that cannot be properly cast cause a
2861 ``ValueError`` to be raised. Once parsed, the querystring arguments
2862 and keyword arguments are passed to the ``ConnectionPool``'s
2863 class initializer. In the case of conflicting arguments, querystring
2864 arguments always win.
2865 """
2866 url_options = parse_url(url)
2868 if "connection_class" in kwargs:
2869 url_options["connection_class"] = kwargs["connection_class"]
2871 kwargs.update(url_options)
2872 return cls(**kwargs)
2874 def __init__(
2875 self,
2876 connection_class=Connection,
2877 max_connections: Optional[int] = None,
2878 cache_factory: Optional[CacheFactoryInterface] = None,
2879 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
2880 **connection_kwargs,
2881 ):
2882 max_connections = max_connections or 100
2883 if not isinstance(max_connections, int) or max_connections < 0:
2884 raise ValueError('"max_connections" must be a positive integer')
2886 self.connection_class = connection_class
2887 self._connection_kwargs = connection_kwargs
2888 self.max_connections = max_connections
2889 self.cache = None
2890 self._cache_factory = cache_factory
2892 try:
2893 supports_maint_notifications = issubclass(
2894 connection_class, MaintNotificationsAbstractConnection
2895 )
2896 is_unix_domain_socket_connection = issubclass(
2897 connection_class, UnixDomainSocketConnection
2898 )
2899 except TypeError:
2900 supports_maint_notifications = False
2901 is_unix_domain_socket_connection = False
2903 if is_unix_domain_socket_connection or not supports_maint_notifications:
2904 if (
2905 maint_notifications_config
2906 and maint_notifications_config.enabled is True
2907 ):
2908 raise RedisError(
2909 "Maintenance notifications are not supported with "
2910 f"{connection_class}"
2911 )
2912 maint_notifications_config = MaintNotificationsConfig(enabled=False)
2914 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
2915 if self._event_dispatcher is None:
2916 self._event_dispatcher = EventDispatcher()
2918 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
2919 if not check_protocol_version(self._connection_kwargs.get("protocol"), 3):
2920 raise RedisError("Client caching is only supported with RESP version 3")
2922 cache = self._connection_kwargs.get("cache")
2924 if cache is not None:
2925 if not isinstance(cache, CacheInterface):
2926 raise ValueError("Cache must implement CacheInterface")
2928 self.cache = cache
2929 else:
2930 if self._cache_factory is not None:
2931 self.cache = CacheProxy(self._cache_factory.get_cache())
2932 else:
2933 self.cache = CacheFactory(
2934 self._connection_kwargs.get("cache_config")
2935 ).get_cache()
2937 init_csc_items()
2938 register_csc_items_callback(
2939 callback=lambda: self.cache.size,
2940 pool_name=get_pool_name(self),
2941 )
2943 connection_kwargs.pop("cache", None)
2944 connection_kwargs.pop("cache_config", None)
2946 # a lock to protect the critical section in _checkpid().
2947 # this lock is acquired when the process id changes, such as
2948 # after a fork. during this time, multiple threads in the child
2949 # process could attempt to acquire this lock. the first thread
2950 # to acquire the lock will reset the data structures and lock
2951 # object of this pool. subsequent threads acquiring this lock
2952 # will notice the first thread already did the work and simply
2953 # release the lock.
2955 self._fork_lock = threading.RLock()
2956 self._lock = threading.RLock()
2958 # Generate unique pool ID for observability (matches go-redis behavior)
2959 import secrets
2961 self._pool_id = secrets.token_hex(4)
2963 MaintNotificationsAbstractConnectionPool.__init__(
2964 self,
2965 maint_notifications_config=maint_notifications_config,
2966 **connection_kwargs,
2967 )
2969 self.reset()
2971 # Keys that should be redacted in __repr__ to avoid exposing sensitive information
2972 SENSITIVE_REPR_KEYS = frozenset(
2973 {
2974 "password",
2975 "username",
2976 "ssl_password",
2977 "credential_provider",
2978 }
2979 )
2981 def __repr__(self) -> str:
2982 conn_kwargs = ",".join(
2983 [
2984 f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
2985 for k, v in self.connection_kwargs.items()
2986 ]
2987 )
2988 return (
2989 f"<{self.__class__.__module__}.{self.__class__.__name__}"
2990 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
2991 f"({conn_kwargs})>)>"
2992 )
2994 @property
2995 def connection_kwargs(self) -> Dict[str, Any]:
2996 return self._connection_kwargs
2998 @connection_kwargs.setter
2999 def connection_kwargs(self, value: Dict[str, Any]):
3000 self._connection_kwargs = value
3002 def get_protocol(self):
3003 """
3004 Returns:
3005 The RESP protocol version, or ``None`` if the protocol is not specified,
3006 in which case the server default will be used.
3007 """
3008 return self.connection_kwargs.get("protocol", None)
    def reset(self) -> None:
        """
        Reinitialize the pool's bookkeeping to an empty state.

        Called as the last step of ``__init__`` and by ``_checkpid()`` when
        the process id has changed (e.g. after a fork). Before the containers
        are replaced, the IDLE/USED connection-count metrics are decremented
        by the number of tracked connections so the gauges stay balanced.
        Existing connection objects are dropped from the pool here; this
        method does not call ``disconnect()`` on them.
        """
        # Record metrics for connections being removed before clearing
        # (only if attributes exist - they won't during __init__)
        if hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        ):
            # NOTE(review): self._lock is created once in __init__ and is not
            # recreated here. When reset() runs in a forked child via
            # _checkpid() and some parent thread held self._lock at fork time,
            # this acquisition can block forever (unlike _fork_lock it has no
            # timeout). Confirm whether that case needs handling.
            with self._lock:
                idle_count = len(self._available_connections)
                in_use_count = len(self._in_use_connections)
                if idle_count > 0 or in_use_count > 0:
                    pool_name = get_pool_name(self)
                    if idle_count > 0:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=ConnectionState.IDLE,
                            counter=-idle_count,
                        )
                    if in_use_count > 0:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=ConnectionState.USED,
                            counter=-in_use_count,
                        )

        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
3049 def __del__(self) -> None:
3050 """Clean up connection pool and record metrics when garbage collected."""
3051 try:
3052 if not hasattr(self, "_available_connections") or not hasattr(
3053 self, "_in_use_connections"
3054 ):
3055 return
3056 # Record metrics for all connections being removed
3057 idle_count = len(self._available_connections)
3058 in_use_count = len(self._in_use_connections)
3059 if idle_count > 0 or in_use_count > 0:
3060 pool_name = get_pool_name(self)
3061 if idle_count > 0:
3062 record_connection_count(
3063 pool_name=pool_name,
3064 connection_state=ConnectionState.IDLE,
3065 counter=-idle_count,
3066 )
3067 if in_use_count > 0:
3068 record_connection_count(
3069 pool_name=pool_name,
3070 connection_state=ConnectionState.USED,
3071 counter=-in_use_count,
3072 )
3073 except Exception:
3074 pass
3076 def _checkpid(self) -> None:
3077 # _checkpid() attempts to keep ConnectionPool fork-safe on modern
3078 # systems. this is called by all ConnectionPool methods that
3079 # manipulate the pool's state such as get_connection() and release().
3080 #
3081 # _checkpid() determines whether the process has forked by comparing
3082 # the current process id to the process id saved on the ConnectionPool
3083 # instance. if these values are the same, _checkpid() simply returns.
3084 #
3085 # when the process ids differ, _checkpid() assumes that the process
3086 # has forked and that we're now running in the child process. the child
3087 # process cannot use the parent's file descriptors (e.g., sockets).
3088 # therefore, when _checkpid() sees the process id change, it calls
3089 # reset() in order to reinitialize the child's ConnectionPool. this
3090 # will cause the child to make all new connection objects.
3091 #
3092 # _checkpid() is protected by self._fork_lock to ensure that multiple
3093 # threads in the child process do not call reset() multiple times.
3094 #
3095 # there is an extremely small chance this could fail in the following
3096 # scenario:
3097 # 1. process A calls _checkpid() for the first time and acquires
3098 # self._fork_lock.
3099 # 2. while holding self._fork_lock, process A forks (the fork()
3100 # could happen in a different thread owned by process A)
3101 # 3. process B (the forked child process) inherits the
3102 # ConnectionPool's state from the parent. that state includes
3103 # a locked _fork_lock. process B will not be notified when
3104 # process A releases the _fork_lock and will thus never be
3105 # able to acquire the _fork_lock.
3106 #
3107 # to mitigate this possible deadlock, _checkpid() will only wait 5
3108 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
3109 # that time it is assumed that the child is deadlocked and a
3110 # redis.ChildDeadlockedError error is raised.
3111 if self.pid != os.getpid():
3112 acquired = self._fork_lock.acquire(timeout=5)
3113 if not acquired:
3114 raise ChildDeadlockedError
3115 # reset() the instance for the new process if another thread
3116 # hasn't already done so
3117 try:
3118 if self.pid != os.getpid():
3119 self.reset()
3120 finally:
3121 self._fork_lock.release()
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """
        Get a connection from the pool.

        If one exists, reuses the most recently released idle connection
        (``_available_connections`` is used as a stack). Otherwise creates a
        new one with ``make_connection()``. The connection is then connected
        and checked for unread data, reconnecting once if necessary.

        ``command_name``, ``keys`` and ``options`` are deprecated (see the
        decorator) and unused by this method.

        Raises:
            MaxConnectionsError: from ``make_connection()`` when
                ``max_connections`` connections already exist.
            ConnectionError: the connection still has unread data after a
                reconnect ("Connection not ready").
            Any exception raised while connecting or validating is re-raised
            after the connection has been released back to the pool.
        """

        # Reset the pool first if the process id changed (e.g. after a fork).
        self._checkpid()
        is_created = False

        with self._lock:
            try:
                # Take the most recently released idle connection.
                connection = self._available_connections.pop()
            except IndexError:
                # No idle connection available. Start timing for
                # observability (connection create time).
                start_time_created = time.monotonic()

                connection = self.make_connection()
                is_created = True
            self._in_use_connections.add(connection)

        # Record state transition: IDLE -> USED
        # (make_connection already recorded IDLE +1 for new connections)
        # This ensures counters stay balanced if connect() fails and release() is called
        pool_name = get_pool_name(self)
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=-1,
        )
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=1,
        )

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # Unread data only forces a reconnect when neither client-side
            # caching nor maintenance notifications are enabled; presumably
            # pending push messages are expected in those modes -- TODO confirm.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        if is_created:
            # Measured from just before make_connection() until the connection
            # was verified, so connect() time is included.
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        return connection
3194 def get_encoder(self) -> Encoder:
3195 "Return an encoder based on encoding settings"
3196 kwargs = self.connection_kwargs
3197 return Encoder(
3198 encoding=kwargs.get("encoding", "utf-8"),
3199 encoding_errors=kwargs.get("encoding_errors", "strict"),
3200 decode_responses=kwargs.get("decode_responses", False),
3201 )
3203 def make_connection(self) -> "ConnectionInterface":
3204 "Create a new connection"
3205 if self._created_connections >= self.max_connections:
3206 raise MaxConnectionsError("Too many connections")
3207 self._created_connections += 1
3209 kwargs = dict(self.connection_kwargs)
3211 # Create the connection first, then record metrics only on success
3212 if self.cache is not None:
3213 connection = CacheProxyConnection(
3214 self.connection_class(**kwargs), self.cache, self._lock
3215 )
3216 else:
3217 connection = self.connection_class(**kwargs)
3219 # Record new connection created (starts as IDLE) - only after successful construction
3220 record_connection_count(
3221 pool_name=get_pool_name(self),
3222 connection_state=ConnectionState.IDLE,
3223 counter=1,
3224 )
3226 return connection
3228 def release(self, connection: "Connection") -> None:
3229 "Releases the connection back to the pool"
3230 self._checkpid()
3231 with self._lock:
3232 try:
3233 self._in_use_connections.remove(connection)
3234 except KeyError:
3235 # Gracefully fail when a connection is returned to this pool
3236 # that the pool doesn't actually own
3237 return
3239 if self.owns_connection(connection):
3240 if connection.should_reconnect():
3241 connection.disconnect()
3242 self._available_connections.append(connection)
3243 self._event_dispatcher.dispatch(
3244 AfterConnectionReleasedEvent(connection)
3245 )
3247 # Record state transition: USED -> IDLE
3248 pool_name = get_pool_name(self)
3249 record_connection_count(
3250 pool_name=pool_name,
3251 connection_state=ConnectionState.USED,
3252 counter=-1,
3253 )
3254 record_connection_count(
3255 pool_name=pool_name,
3256 connection_state=ConnectionState.IDLE,
3257 counter=1,
3258 )
3259 else:
3260 # Pool doesn't own this connection, do not add it back
3261 # to the pool.
3262 # The created connections count should not be changed,
3263 # because the connection was not created by the pool.
3264 # Still need to decrement USED since it was counted in get_connection()
3265 connection.disconnect()
3266 record_connection_count(
3267 pool_name="unknown_pool",
3268 connection_state=ConnectionState.USED,
3269 counter=-1,
3270 )
3271 return
3273 def owns_connection(self, connection: "Connection") -> int:
3274 return connection.pid == self.pid
3276 def disconnect(self, inuse_connections: bool = True) -> None:
3277 """
3278 Disconnects connections in the pool
3280 If ``inuse_connections`` is True, disconnect connections that are
3281 currently in use, potentially by other threads. Otherwise only disconnect
3282 connections that are idle in the pool.
3283 """
3284 self._checkpid()
3285 with self._lock:
3286 if inuse_connections:
3287 connections = chain(
3288 self._available_connections, self._in_use_connections
3289 )
3290 else:
3291 connections = self._available_connections
3293 for connection in connections:
3294 connection.disconnect()
3296 def close(self) -> None:
3297 """Close the pool, disconnecting all connections"""
3298 self.disconnect()
3300 def set_retry(self, retry: Retry) -> None:
3301 self.connection_kwargs.update({"retry": retry})
3302 for conn in self._available_connections:
3303 conn.retry = retry
3304 for conn in self._in_use_connections:
3305 conn.retry = retry
3307 def re_auth_callback(self, token: TokenInterface):
3308 with self._lock:
3309 for conn in self._available_connections:
3310 conn.retry.call_with_retry(
3311 lambda: conn.send_command(
3312 "AUTH", token.try_get("oid"), token.get_value()
3313 ),
3314 lambda error: self._mock(error),
3315 )
3316 conn.retry.call_with_retry(
3317 lambda: conn.read_response(), lambda error: self._mock(error)
3318 )
3319 for conn in self._in_use_connections:
3320 conn.set_re_auth_token(token)
3322 def _get_pool_lock(self):
3323 return self._lock
3325 def _get_free_connections(self):
3326 with self._lock:
3327 return list(self._available_connections)
3329 def _get_in_use_connections(self):
3330 with self._lock:
3331 return set(self._in_use_connections)
3333 def _mock(self, error: RedisError):
3334 """
3335 Dummy functions, needs to be passed as error callback to retry object.
3336 :param error:
3337 :return:
3338 """
3339 pass
3341 def get_connection_count(self) -> List[tuple[int, dict]]:
3342 from redis.observability.attributes import get_pool_name
3344 attributes = AttributeBuilder.build_base_attributes()
3345 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
3346 free_connections_attributes = attributes.copy()
3347 in_use_connections_attributes = attributes.copy()
3349 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3350 ConnectionState.IDLE.value
3351 )
3352 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3353 ConnectionState.USED.value
3354 )
3356 return [
3357 (len(self._get_free_connections()), free_connections_attributes),
3358 (len(self._get_in_use_connections()), in_use_connections_attributes),
3359 ]
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        self.queue_class = queue_class
        self.timeout = timeout
        self._in_maintenance = False
        # Kept for backward compatibility only. Lock ownership is tracked per
        # call (see _acquire_lock_if_in_maintenance) because a shared
        # instance flag is unsafe with nested calls and multiple threads.
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def _acquire_lock_if_in_maintenance(self) -> bool:
        """
        Acquire ``self._lock`` if the pool is currently in maintenance mode.

        Returns:
            True if the lock was acquired. The caller must then pass this value
            to ``_release_lock_if_acquired()`` in a ``finally`` block.

        Ownership lives in the caller's local variable instead of a shared
        instance flag. With a shared flag, ``get_connection()`` ->
        ``make_connection()`` nests two acquisitions; the inner call clears
        the flag, so the outer call skips its release and one level of the
        reentrant lock stays held forever. Concurrent threads can also clear
        or consume each other's flag.
        """
        if self._in_maintenance:
            self._lock.acquire()
            return True
        return False

    def _release_lock_if_acquired(self, locked: bool) -> None:
        """Release ``self._lock`` when *locked* says this call acquired it."""
        if locked:
            self._lock.release()

    def _record_removed_connection_metrics(self) -> None:
        """
        Decrement the IDLE/USED connection-count metrics for every connection
        this pool has created.

        Idle connections are the non-``None`` entries of ``self.pool.queue``.
        Every other connection in ``self._connections`` counts as in use.
        The queue is read directly; callers decide whether to hold the lock.
        """
        connections_in_queue = {conn for conn in self.pool.queue if conn}
        idle_count = len(connections_in_queue)
        in_use_count = len(self._connections) - idle_count
        if idle_count <= 0 and in_use_count <= 0:
            return
        pool_name = get_pool_name(self)
        if idle_count > 0:
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-idle_count,
            )
        if in_use_count > 0:
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=-in_use_count,
            )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        locked = self._acquire_lock_if_in_maintenance()
        try:
            # Record metrics for connections being removed before clearing.
            # self._lock is a reentrant lock (RLock), so taking it here is
            # safe even when this thread already holds it for maintenance.
            if (
                hasattr(self, "_connections")
                and self._connections
                and hasattr(self, "pool")
            ):
                with self._lock:
                    self._record_removed_connection_metrics()

            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            self._release_lock_if_acquired(locked)

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        try:
            # Reads pool.queue without taking self._lock so that garbage
            # collection can never block waiting for the lock.
            if (
                hasattr(self, "_connections")
                and self._connections
                and hasattr(self, "pool")
            ):
                self._record_removed_connection_metrics()
        except Exception:
            # Best effort: exceptions must not escape a finalizer.
            pass

    def make_connection(self):
        "Make a fresh connection."
        locked = self._acquire_lock_if_in_maintenance()
        try:
            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            self._connections.append(connection)

            # Record new connection created (starts as IDLE)
            record_connection_count(
                pool_name=get_pool_name(self),
                connection_state=ConnectionState.IDLE,
                counter=1,
            )

            return connection
        finally:
            self._release_lock_if_acquired(locked)

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        ``command_name``, ``keys`` and ``options`` are deprecated and unused.

        Raises:
            ConnectionError: no connection became available within
                ``self.timeout`` seconds, or the connection still had unread
                data after a reconnect. Any other exception raised while
                connecting is re-raised after the connection is released.
        """
        start_time_acquired = time.monotonic()
        # Make sure we haven't changed process.
        self._checkpid()
        is_created = False

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        locked = self._acquire_lock_if_in_maintenance()
        try:
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. To wait forever
                # instead, create the pool with ``timeout=None``.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                # Start timing for observability
                start_time_created = time.monotonic()
                try:
                    connection = self.make_connection()
                except BaseException:
                    # Give the placeholder back. Otherwise a failed
                    # construction would permanently shrink the pool.
                    try:
                        self.pool.put_nowait(None)
                    except Full:
                        pass
                    raise
                is_created = True
        finally:
            self._release_lock_if_acquired(locked)

        # Record state transition: IDLE -> USED
        # (make_connection already recorded IDLE +1 for new connections)
        # This ensures counters stay balanced if connect() fails and release() is called
        pool_name = get_pool_name(self)
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=-1,
        )
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=1,
        )

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        record_connection_wait_time(
            pool_name=pool_name,
            duration_seconds=time.monotonic() - start_time_acquired,
        )

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()

        locked = self._acquire_lock_if_in_maintenance()
        try:
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                # Still need to decrement USED since it was counted in get_connection()
                record_connection_count(
                    pool_name="unknown_pool",
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            pool_name = get_pool_name(self)
            try:
                self.pool.put_nowait(connection)
            except Full:
                # The queue is already full (e.g. the pool was reset in the
                # meantime); drop this connection.
                pass
            else:
                # Record state transition: USED -> IDLE
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=1,
                )
        finally:
            self._release_lock_if_acquired(locked)

    def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects either all connections in the pool or just the free connections.
        """
        self._checkpid()
        locked = self._acquire_lock_if_in_maintenance()
        try:
            if inuse_connections:
                connections = self._connections
            else:
                # Takes self._lock itself; safe because it is reentrant.
                connections = self._get_free_connections()

            for connection in connections:
                connection.disconnect()
        finally:
            self._release_lock_if_acquired(locked)

    def _get_free_connections(self):
        """Return the set of idle connections currently queued in the pool."""
        with self._lock:
            return {conn for conn in self.pool.queue if conn}

    def _get_in_use_connections(self):
        """Return the set of created connections that are not in the queue."""
        with self._lock:
            # free connections
            connections_in_queue = {conn for conn in self.pool.queue if conn}
            # in self._connections we keep all created connections
            # so the ones that are not in the queue are the in use ones
            return {
                conn for conn in self._connections if conn not in connections_in_queue
            }

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this Blocking ConnectionPool is in maintenance mode.

        This is used to prevent new connections from being created while we are in maintenance mode.
        The pool will be in maintenance mode only when we are processing a MOVING notification.
        While the flag is set, reset(), make_connection(), get_connection(),
        release() and disconnect() run while holding ``self._lock``.
        """
        self._in_maintenance = in_maintenance