Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32 CacheProxy,
33)
35from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
36from ._parsers.socket import SENTINEL
37from .auth.token import TokenInterface
38from .backoff import NoBackoff
39from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
40from .driver_info import DriverInfo, resolve_driver_info
41from .event import AfterConnectionReleasedEvent, EventDispatcher
42from .exceptions import (
43 AuthenticationError,
44 AuthenticationWrongNumberOfArgsError,
45 ChildDeadlockedError,
46 ConnectionError,
47 DataError,
48 MaxConnectionsError,
49 RedisError,
50 ResponseError,
51 TimeoutError,
52)
53from .maint_notifications import (
54 MaintenanceState,
55 MaintNotificationsConfig,
56 MaintNotificationsConnectionHandler,
57 MaintNotificationsPoolHandler,
58 OSSMaintNotificationsHandler,
59)
60from .observability.attributes import (
61 DB_CLIENT_CONNECTION_POOL_NAME,
62 DB_CLIENT_CONNECTION_STATE,
63 AttributeBuilder,
64 ConnectionState,
65 CSCReason,
66 CSCResult,
67 get_pool_name,
68)
69from .observability.metrics import CloseReason
70from .observability.recorder import (
71 init_csc_items,
72 record_connection_closed,
73 record_connection_count,
74 record_connection_create_time,
75 record_connection_wait_time,
76 record_csc_eviction,
77 record_csc_network_saved,
78 record_csc_request,
79 record_error_count,
80 register_csc_items_callback,
81)
82from .retry import Retry
83from .utils import (
84 CRYPTOGRAPHY_AVAILABLE,
85 DEFAULT_RESP_VERSION,
86 HIREDIS_AVAILABLE,
87 SSL_AVAILABLE,
88 check_protocol_version,
89 compare_versions,
90 deprecated_args,
91 ensure_string,
92 format_error_message,
93 str_if_bytes,
94)
# The ssl module is optional: when Python was built without it, both names
# are bound to None so later code can test for availability.
if not SSL_AVAILABLE:
    ssl = None
    VerifyFlags = None
else:
    import ssl
    from ssl import VerifyFlags

# hiredis is an optional accelerator; only import it when it is installed.
if HIREDIS_AVAILABLE:
    import hiredis

# Byte fragments used when assembling RESP command frames.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# Parser class used when the caller does not pick one explicitly:
# hiredis when available, otherwise the pure-Python RESP3 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP3Parser
)
class HiredisRespSerializer:
    """Command packer that delegates the encoding to ``hiredis.pack_command``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol

        The first argument may contain several space-separated words
        (e.g. ``'CONFIG GET'``); it is split so that each word is sent as
        its own argument.

        Returns:
            A one-element list holding the packed command.

        Raises:
            DataError: if ``hiredis.pack_command`` rejects an argument with
                a ``TypeError``.
        """
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Re-raise as the library's own error type, keeping the
            # traceback of the original failure.
            raise DataError(exc).with_traceback(exc.__traceback__)

        return [packed]
class PythonRespSerializer:
    """Pure-Python command packer producing RESP frames."""

    def __init__(self, buffer_cutoff, encode) -> None:
        """
        Args:
            buffer_cutoff: size threshold above which an argument (or the
                pending buffer) is emitted as its own chunk.
            encode: callable applied to every argument before packing.
        """
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol

        Returns:
            A list of buffers which, concatenated in order, form the
            complete command.
        """
        # The command name may carry literal sub-arguments ('CONFIG GET').
        # They must travel as separate arguments, so split the first one by
        # hand; the pieces are bytestrings so they are not encoded again.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            # Everything accumulated so far plus this argument's length header.
            with_header = SYM_EMPTY.join(
                (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF)
            )
            # Large values and memoryviews are appended as their own chunk
            # to avoid big string allocations.
            if (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            ):
                chunks.append(with_header)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((with_header, encoded, SYM_CRLF))

        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """
    Declares the operations a connection object is expected to provide.

    Every member is marked with ``@abstractmethod`` and has an empty body.
    Note that this class itself does not derive from ``ABC``.
    """

    @abstractmethod
    def repr_pieces(self):
        """Abstract: supply the pieces used to build the connection's repr."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Abstract: register ``callback`` for connect events."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Abstract: remove a previously registered ``callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Abstract: install a parser built from ``parser_class``."""

    @abstractmethod
    def get_protocol(self):
        """Abstract: report the protocol in use."""

    @abstractmethod
    def connect(self):
        """Abstract: establish the connection."""

    @abstractmethod
    def on_connect(self):
        """Abstract: hook associated with connecting."""

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Abstract: tear the connection down."""

    @abstractmethod
    def check_health(self):
        """Abstract: perform a health check."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Abstract: send an already packed ``command``."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Abstract: send a command given as individual arguments."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Abstract: tell whether data can be read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Abstract: read one response; all options but the first are keyword-only."""

    @abstractmethod
    def pack_command(self, *args):
        """Abstract: pack a single command."""

    @abstractmethod
    def pack_commands(self, commands):
        """Abstract: pack several commands."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Abstract property: handshake metadata as a bytes- or str-keyed dict."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Abstract: store ``token`` for re-authentication."""

    @abstractmethod
    def re_auth(self):
        """Abstract: re-authenticate the connection."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """

    @abstractmethod
    def extract_connection_details(self) -> str:
        """Abstract: return the connection details as a string."""

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """
        Return ``True`` if the connection to the server is active.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the push handlers are attached to.
                It is only consulted when notifications are enabled in the
                config; in that case a missing parser raises RedisError.
            event_dispatcher (Optional[EventDispatcher]): Dispatcher stored on the
                connection. A new EventDispatcher is created when omitted.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self.event_dispatcher = (
            event_dispatcher if event_dispatcher is not None else EventDispatcher()
        )

        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )

        # Bookkeeping of notification ids seen on this connection.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Abstract: return the parser attached to the connection."""

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Abstract: return the underlying socket, or None."""

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """

    @property
    @abstractmethod
    def host(self) -> str:
        """Abstract property: host the connection targets."""

    @host.setter
    @abstractmethod
    def host(self, value: str):
        """Abstract setter for ``host``."""

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Abstract property: socket timeout of the connection."""

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        """Abstract setter for ``socket_timeout``."""

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Abstract property: socket connect timeout of the connection."""

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        """Abstract setter for ``socket_connect_timeout``."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Abstract: send a command over the connection."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Abstract: read one response from the connection."""

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Abstract: close the connection."""

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        When the config is missing or disabled, the three handler attributes
        are set to None and nothing else happens (in particular the
        ``orig_*`` attributes are not assigned).

        Raises:
            RedisError: notifications are enabled but ``parser`` is missing,
                or is neither a hiredis nor a RESP3 parser.
        """
        config = self.maint_notifications_config
        if not (config and config.enabled):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )
        if not isinstance(parser, (_HiredisParser, _RESP3Parser)):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Work on a per-connection handler obtained from the pool handler
            # so that the handler attached to the parser references the
            # connection the parser belongs to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )
        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler or None
        )

        # Wire the available handlers into the parser: OSS cluster handler,
        # then pool handler, then the always-present connection handler.
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Remember the original connection parameters; a falsy argument
        # falls back to the connection's current value.
        self.orig_host_address = orig_host_address or self.host
        self.orig_socket_timeout = orig_socket_timeout or self.socket_timeout
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout or self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Attach a per-connection version of the given pool handler to this
        connection and its parser, and make sure a connection handler using
        the pool handler's config exists.
        """
        # Each connection gets its own handler object; sharing one would let
        # every connection overwrite the handler's connection reference, so
        # it would only ever point at the last connection that was set.
        own_handler = maint_notifications_pool_handler.get_handler_for_connection()
        own_handler.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            own_handler.handle_notification
        )
        self._maint_notifications_pool_handler = own_handler

        if self._maint_notifications_connection_handler:
            # A connection handler already exists: only refresh its config.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )
            return

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(
                self, maint_notifications_pool_handler.config
            )
        )
        self._get_parser().set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """
        Attach the given OSS cluster handler to this connection and its
        parser, and make sure a connection handler using the cluster
        handler's config exists.
        """
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )
        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        if self._maint_notifications_connection_handler:
            # A connection handler already exists: only refresh its config.
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )
            return

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(
                self, oss_cluster_maint_notifications_handler.config
            )
        )
        self._get_parser().set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """
        Send the maintenance notifications handshake when all of these hold:
        the protocol is not RESP2, notifications are enabled in the config,
        a connection handler is set up, and the object has a ``host``
        attribute (needed to determine the endpoint type).

        How a failing handshake is treated depends on the config's
        ``enabled`` value, see ``_enable_maintenance_notifications``.
        """
        if self.get_protocol() in (2, "2"):
            return
        config = self.maint_notifications_config
        if not (config and config.enabled):
            return
        if not self._maint_notifications_connection_handler:
            return
        if not hasattr(self, "host"):
            return

        self._enable_maintenance_notifications(
            maint_notifications_config=config,
            check_health=check_health,
        )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Issue ``CLIENT MAINT_NOTIFICATIONS ON`` and verify the reply is OK.

        Raises:
            ValueError: the object has no ``host`` (or it is None).
            ResponseError: the server did not answer OK, unless the config's
                ``enabled`` value is "auto", in which case the failure is
                only logged at debug level.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)

            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except ResponseError as e:
            if maint_notifications_config.enabled != "auto":
                raise
            # In "auto" mode the connection must not fail: just log it.
            import logging

            logging.getLogger(__name__).debug(
                f"Failed to enable maintenance notifications: {e}"
            )

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """
        # Attempt 1: ask the connected socket for its peer address; this is
        # the exact address in use.
        try:
            sock = self._get_socket()
            if sock is not None:
                peer = sock.getpeername()
                if peer and len(peer) >= 1:
                    # For TCP sockets the peer is typically a (host, port)
                    # tuple; only the host part is returned.
                    return peer[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail.
            pass

        # Attempt 2: resolve the configured host. Less accurate, but works
        # without a socket.
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                candidates = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if candidates:
                    # Each entry is (family, socktype, proto, canonname,
                    # sockaddr); sockaddr[0] of the first entry is the IP.
                    return str(candidates[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail.
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state stored on the connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        """Record ``id`` as a processed start notification."""
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        """Return the set of processed start notification ids (not a copy)."""
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        """Record ``id`` as a skipped end notification."""
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        """Return the set of skipped end notification ids (not a copy)."""
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        """Forget all recorded start and end notification ids."""
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        sock = self._get_socket()
        if not sock:
            return None
        return sock.getpeername()[0]

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Apply ``relaxed_timeout`` to the live socket and the parser.
        A value of -1 means "use the connection's ``socket_timeout``".
        Does nothing when there is no socket.
        """
        sock = self._get_socket()
        if not sock:
            return

        timeout = self.socket_timeout if relaxed_timeout == -1 else relaxed_timeout
        # A current timeout of 0 means a can_read call is in progress. That
        # operation is non-blocking and must return immediately; switching
        # the socket to blocking mid-read would lead to a deadlock, so the
        # socket timeout is left alone in that case.
        if sock.gettimeout() != 0:
            sock.settimeout(timeout)
        # NOTE(review): the original nesting of this call could not be
        # recovered from the flattened source; it is placed outside the
        # gettimeout() check - confirm against the upstream file.
        self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """
        Push ``timeout`` down to the parser, provided the parser exists and
        has a truthy ``_buffer``.
        """
        parser = self._get_parser()
        if not parser or not parser._buffer:
            return

        # NOTE(review): for the RESP3 parser a falsy timeout (None or 0) is
        # not written at all - confirm this is intended.
        if isinstance(parser, _RESP3Parser) and timeout:
            parser._buffer.socket_timeout = timeout
        elif isinstance(parser, _HiredisParser):
            parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        A ``tmp_relaxed_timeout`` of -1 leaves both timeouts untouched; any
        other value is written to ``socket_timeout`` and
        ``socket_connect_timeout``.
        """
        host_given = tmp_host_address and tmp_host_address != SENTINEL
        if host_given:
            self.host = str(tmp_host_address)

        if tmp_relaxed_timeout == -1:
            return
        self.socket_timeout = tmp_relaxed_timeout
        self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Restore the stored ``orig_*`` values for the host and/or the two
        timeouts, depending on which flags are set.
        """
        if reset_host_address:
            self.host = self.orig_host_address

        if not reset_relaxed_timeout:
            return
        self.socket_timeout = self.orig_socket_timeout
        self.socket_connect_timeout = self.orig_socket_connect_timeout
772class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
773 "Manages communication to and from a Redis server"
775 @deprecated_args(
776 args_to_warn=["lib_name", "lib_version"],
777 reason="Use 'driver_info' parameter instead. "
778 "lib_name and lib_version will be removed in a future version.",
779 )
780 def __init__(
781 self,
782 db: int = 0,
783 password: Optional[str] = None,
784 socket_timeout: Optional[float] = None,
785 socket_connect_timeout: Optional[float] = None,
786 retry_on_timeout: bool = False,
787 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
788 encoding: str = "utf-8",
789 encoding_errors: str = "strict",
790 decode_responses: bool = False,
791 parser_class=DefaultParser,
792 socket_read_size: int = 65536,
793 health_check_interval: int = 0,
794 client_name: Optional[str] = None,
795 lib_name: Optional[str] = None,
796 lib_version: Optional[str] = None,
797 driver_info: Optional[DriverInfo] = None,
798 username: Optional[str] = None,
799 retry: Union[Any, None] = None,
800 redis_connect_func: Optional[Callable[[], None]] = None,
801 credential_provider: Optional[CredentialProvider] = None,
802 protocol: Optional[int] = 3,
803 command_packer: Optional[Callable[[], None]] = None,
804 event_dispatcher: Optional[EventDispatcher] = None,
805 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
806 maint_notifications_pool_handler: Optional[
807 MaintNotificationsPoolHandler
808 ] = None,
809 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
810 maintenance_notification_hash: Optional[int] = None,
811 orig_host_address: Optional[str] = None,
812 orig_socket_timeout: Optional[float] = None,
813 orig_socket_connect_timeout: Optional[float] = None,
814 oss_cluster_maint_notifications_handler: Optional[
815 OSSMaintNotificationsHandler
816 ] = None,
817 ):
818 """
819 Initialize a new Connection.
821 To specify a retry policy for specific errors, first set
822 `retry_on_error` to a list of the error/s to retry on, then set
823 `retry` to a valid `Retry` object.
824 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
826 Parameters
827 ----------
828 driver_info : DriverInfo, optional
829 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
830 are ignored. If not provided, a DriverInfo will be created from lib_name
831 and lib_version (or defaults if those are also None).
832 lib_name : str, optional
833 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
834 lib_version : str, optional
835 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
836 """
837 if (username or password) and credential_provider is not None:
838 raise DataError(
839 "'username' and 'password' cannot be passed along with 'credential_"
840 "provider'. Please provide only one of the following arguments: \n"
841 "1. 'password' and (optional) 'username'\n"
842 "2. 'credential_provider'"
843 )
844 if event_dispatcher is None:
845 self._event_dispatcher = EventDispatcher()
846 else:
847 self._event_dispatcher = event_dispatcher
848 self.pid = os.getpid()
849 self.db = db
850 self.client_name = client_name
852 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
853 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
855 self.credential_provider = credential_provider
856 self.password = password
857 self.username = username
858 self._socket_timeout = socket_timeout
859 if socket_connect_timeout is None:
860 socket_connect_timeout = socket_timeout
861 self._socket_connect_timeout = socket_connect_timeout
862 self.retry_on_timeout = retry_on_timeout
863 if retry_on_error is SENTINEL:
864 retry_on_errors_list = []
865 else:
866 retry_on_errors_list = list(retry_on_error)
867 if retry_on_timeout:
868 # Add TimeoutError to the errors list to retry on
869 retry_on_errors_list.append(TimeoutError)
870 self.retry_on_error = retry_on_errors_list
871 if retry or self.retry_on_error:
872 if retry is None:
873 self.retry = Retry(NoBackoff(), 1)
874 else:
875 # deep-copy the Retry object as it is mutable
876 self.retry = copy.deepcopy(retry)
877 if self.retry_on_error:
878 # Update the retry's supported errors with the specified errors
879 self.retry.update_supported_errors(self.retry_on_error)
880 else:
881 self.retry = Retry(NoBackoff(), 0)
882 self.health_check_interval = health_check_interval
883 self.next_health_check = 0
884 self.redis_connect_func = redis_connect_func
885 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
886 self.handshake_metadata = None
887 self._sock = None
888 self._socket_read_size = socket_read_size
889 self._connect_callbacks = []
890 self._buffer_cutoff = 6000
891 self._re_auth_token: Optional[TokenInterface] = None
892 try:
893 p = int(protocol)
894 except TypeError:
895 p = DEFAULT_RESP_VERSION
896 except ValueError:
897 raise ConnectionError("protocol must be an integer")
898 else:
899 if p < 2 or p > 3:
900 raise ConnectionError("protocol must be either 2 or 3")
901 self.protocol = p
902 # Reconcile parser ↔ protocol mismatches.
903 # Hiredis handles both RESP2 and RESP3 natively, so only
904 # pure-Python parsers need to be swapped.
905 if self.protocol == 3 and parser_class == _RESP2Parser:
906 parser_class = _RESP3Parser
907 elif self.protocol == 2 and parser_class == _RESP3Parser:
908 parser_class = _RESP2Parser
909 self.set_parser(parser_class)
911 self._command_packer = self._construct_command_packer(command_packer)
912 self._should_reconnect = False
914 # Set up maintenance notifications
915 MaintNotificationsAbstractConnection.__init__(
916 self,
917 maint_notifications_config,
918 maint_notifications_pool_handler,
919 maintenance_state,
920 maintenance_notification_hash,
921 orig_host_address,
922 orig_socket_timeout,
923 orig_socket_connect_timeout,
924 oss_cluster_maint_notifications_handler,
925 self._parser,
926 event_dispatcher=self._event_dispatcher,
927 )
929 def __repr__(self):
930 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
931 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
933 @abstractmethod
934 def repr_pieces(self):
935 pass
937 def __del__(self):
938 try:
939 self.disconnect()
940 except Exception:
941 pass
943 @property
944 def is_connected(self) -> bool:
945 return self._sock is not None
947 def _construct_command_packer(self, packer):
948 if packer is not None:
949 return packer
950 elif HIREDIS_AVAILABLE:
951 return HiredisRespSerializer()
952 else:
953 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
955 def register_connect_callback(self, callback):
956 """
957 Register a callback to be called when the connection is established either
958 initially or reconnected. This allows listeners to issue commands that
959 are ephemeral to the connection, for example pub/sub subscription or
960 key tracking. The callback must be a _method_ and will be kept as
961 a weak reference.
962 """
963 wm = weakref.WeakMethod(callback)
964 if wm not in self._connect_callbacks:
965 self._connect_callbacks.append(wm)
967 def deregister_connect_callback(self, callback):
968 """
969 De-register a previously registered callback. It will no-longer receive
970 notifications on connection events. Calling this is not required when the
971 listener goes away, since the callbacks are kept as weak methods.
972 """
973 try:
974 self._connect_callbacks.remove(weakref.WeakMethod(callback))
975 except ValueError:
976 pass
978 def set_parser(self, parser_class):
979 """
980 Creates a new instance of parser_class with socket size:
981 _socket_read_size and assigns it to the parser for the connection
982 :param parser_class: The required parser class
983 """
984 self._parser = parser_class(socket_read_size=self._socket_read_size)
986 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
987 return self._parser
989 def connect(self):
990 "Connects to the Redis server if not already connected"
991 # try once the socket connect with the handshake, retry the whole
992 # connect/handshake flow based on retry policy
993 self.retry.call_with_retry(
994 lambda: self.connect_check_health(
995 check_health=True, retry_socket_connect=False
996 ),
997 lambda error: self.disconnect(error),
998 )
def connect_check_health(
    self, check_health: bool = True, retry_socket_connect: bool = True
):
    """
    Open the socket, run the handshake and fire the connect callbacks.

    Does nothing when a socket is already attached.

    :param check_health: forwarded to ``on_connect_check_health`` when the
        default handshake is used.
    :param retry_socket_connect: when true, the raw socket connect is run
        through ``self.retry``; otherwise ``_connect`` is called once.
    :raises TimeoutError: the socket connect timed out.
    :raises ConnectionError: the socket connect failed with an ``OSError``.
    :raises RedisError: the handshake failed; the connection is
        disconnected before the error is re-raised.
    """
    if self._sock:
        return

    # Track actual retry attempts for error reporting. A one-element list
    # is used so the nested callback can update the value in place.
    actual_retry_attempts = [0]

    def failure_callback(error, failure_count):
        actual_retry_attempts[0] = failure_count
        self.disconnect(error=error, failure_count=failure_count)

    def record_failure(error):
        # Not every connection type is guaranteed to expose ``host`` and
        # ``port``; read them defensively in *both* failure branches so a
        # missing attribute can never mask the real connect error.
        host = getattr(self, "host", None)
        port = getattr(self, "port", None)
        record_error_count(
            server_address=host,
            server_port=port,
            network_peer_address=host,
            network_peer_port=port,
            error_type=error,
            retry_attempts=actual_retry_attempts[0],
        )

    try:
        if retry_socket_connect:
            sock = self.retry.call_with_retry(
                self._connect,
                failure_callback,
                with_failure_count=True,
            )
        else:
            sock = self._connect()
    except socket.timeout:
        timeout_error = TimeoutError("Timeout connecting to server")
        record_failure(timeout_error)
        raise timeout_error
    except OSError as os_error:
        connection_error = ConnectionError(self._error_message(os_error))
        record_failure(connection_error)
        raise connection_error

    self._sock = sock
    try:
        if self.redis_connect_func is None:
            # Use the default on_connect function
            self.on_connect_check_health(check_health=check_health)
        else:
            # Use the passed function redis_connect_func
            self.redis_connect_func(self)
    except RedisError:
        # clean up after any error in on_connect
        self.disconnect()
        raise

    # run any user callbacks. right now the only internal callback
    # is for pubsub channel/pattern resubscription
    # first, remove any dead weakrefs
    self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
    for ref in self._connect_callbacks:
        callback = ref()
        if callback:
            callback(self)
1066 @abstractmethod
1067 def _connect(self):
1068 pass
1070 @abstractmethod
1071 def _host_error(self):
1072 pass
def _error_message(self, exception):
    """Format ``exception`` together with this connection's endpoint description."""
    endpoint = self._host_error()
    return format_error_message(endpoint, exception)
def on_connect(self):
    """Run the connection handshake with health checks enabled."""
    health_checks_enabled = True
    self.on_connect_check_health(check_health=health_checks_enabled)
def on_connect_check_health(self, check_health: bool = True):
    """
    Initialize the connection, authenticate and select a database.

    The steps run strictly in this order: notify the parser, resolve
    credentials, negotiate protocol/authentication (``HELLO`` or ``AUTH``),
    activate maintenance notifications, then send ``CLIENT SETNAME``,
    ``CLIENT SETINFO`` and ``SELECT`` as configured.

    :param check_health: forwarded to ``send_command`` for every handshake
        command except the authentication ones, which always disable it.
    :raises AuthenticationError: ``AUTH`` did not answer ``OK``.
    :raises ConnectionError: the RESP version, client name or database
        could not be set.
    """
    self._parser.on_connect(self)
    # Keep a handle on the original parser: if the parser gets swapped
    # below, its EXCEPTION_CLASSES are copied onto the replacement.
    parser = self._parser

    auth_args = None
    # if credential provider or username and/or password are set, authenticate
    if self.credential_provider or (self.username or self.password):
        cred_provider = (
            self.credential_provider
            or UsernamePasswordCredentialProvider(self.username, self.password)
        )
        auth_args = cred_provider.get_credentials()

    # if resp version is specified and we have auth args,
    # we need to send them via HELLO
    if auth_args and self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        # A single credential is treated as the password of the
        # "default" user.
        if len(auth_args) == 1:
            auth_args = ["default", auth_args[0]]
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command(
            "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
        )
        self.handshake_metadata = self.read_response()
        # NOTE(review): the RESP version check below is disabled on this
        # path, unlike the unauthenticated HELLO branch further down.
        # if response.get(b"proto") != self.protocol and response.get(
        #     "proto"
        # ) != self.protocol:
        #     raise ConnectionError("Invalid RESP version")
    elif auth_args:
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # a username and password were specified but the Redis
            # server seems to be < 6.0.0 which expects a single password
            # arg. retry auth with just the password.
            # https://github.com/andymccurdy/redis-py/issues/1274
            self.send_command("AUTH", auth_args[-1], check_health=False)
            auth_response = self.read_response()

        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("Invalid Username or Password")

    # if resp version is specified, switch to it
    elif self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        self.send_command("HELLO", self.protocol, check_health=check_health)
        self.handshake_metadata = self.read_response()
        # The metadata may be keyed by bytes or by str, so both spellings
        # of "proto" are checked.
        if (
            self.handshake_metadata.get(b"proto") != self.protocol
            and self.handshake_metadata.get("proto") != self.protocol
        ):
            raise ConnectionError("Invalid RESP version")

    # Activate maintenance notifications for this connection
    # if enabled in the configuration
    # This is a no-op if maintenance notifications are not enabled
    self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

    # if a client_name is given, set it
    if self.client_name:
        self.send_command(
            "CLIENT",
            "SETNAME",
            self.client_name,
            check_health=check_health,
        )
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Error setting client name")

    # Set the library name and version from driver_info.
    # A ResponseError from CLIENT SETINFO is deliberately ignored, making
    # both calls best-effort.
    try:
        if self.driver_info and self.driver_info.formatted_name:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    try:
        if self.driver_info and self.driver_info.lib_version:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    # if a database is specified, switch to it
    # (a falsy db, e.g. 0, sends no SELECT at all)
    if self.db:
        self.send_command("SELECT", self.db, check_health=check_health)
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Invalid Database")
def disconnect(self, *args, **kwargs):
    """Disconnects from the Redis server.

    Recognised keyword arguments are ``error``, ``failure_count`` and
    ``health_check_failed``; they only influence which close reason and
    error metrics are recorded. Positional arguments are accepted but not
    used. When no socket is attached, only the parser and the reconnect
    flag are reset.
    """
    self._parser.on_disconnect()

    # Detach the socket first so the connection reads as closed even if the
    # teardown below misbehaves.
    sock, self._sock = self._sock, None
    # reset the reconnect flag
    self.reset_should_reconnect()

    if sock is None:
        return

    # shutdown() is only attempted when the current process id matches the
    # one stored on the connection; close() is attempted regardless.
    if os.getpid() == self.pid:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except (OSError, TypeError):
            pass

    try:
        sock.close()
    except OSError:
        pass

    error = kwargs.get("error")
    failure_count = kwargs.get("failure_count")
    health_check_failed = kwargs.get("health_check_failed")

    if not error:
        record_connection_closed(
            close_reason=CloseReason.APPLICATION_CLOSE,
        )
    else:
        close_reason = (
            CloseReason.HEALTHCHECK_FAILED
            if health_check_failed
            else CloseReason.ERROR
        )

        retries_exhausted = (
            failure_count is not None
            and failure_count > self.retry.get_retries()
        )
        if retries_exhausted:
            record_error_count(
                server_address=self.host,
                server_port=self.port,
                network_peer_address=self.host,
                network_peer_port=self.port,
                error_type=error,
                retry_attempts=failure_count,
            )

        record_connection_closed(
            close_reason=close_reason,
            error_type=error,
        )

    if self.maintenance_state == MaintenanceState.MAINTENANCE:
        # this block will be executed only if the connection was in maintenance state
        # and the connection was closed.
        # The state change won't be applied on connections that are in Moving state
        # because their state and configurations will be handled when the moving ttl expires.
        self.reset_tmp_settings(reset_relaxed_timeout=True)
        self.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.reset_received_notifications()
def mark_for_reconnect(self):
    """Raise the flag reported by ``should_reconnect``."""
    setattr(self, "_should_reconnect", True)
def should_reconnect(self):
    """Return the current value of the reconnect flag."""
    reconnect_requested = self._should_reconnect
    return reconnect_requested
def reset_should_reconnect(self):
    """Clear the reconnect flag."""
    setattr(self, "_should_reconnect", False)
def _send_ping(self):
    """Send PING, expect PONG in return"""
    # check_health=False: this method *is* the health check.
    self.send_command("PING", check_health=False)
    reply = str_if_bytes(self.read_response())
    if reply == "PONG":
        return
    raise ConnectionError("Bad response from PING health check")
1274 def _ping_failed(self, error, failure_count):
1275 """Function to call when PING fails"""
1276 self.disconnect(
1277 error=error, failure_count=failure_count, health_check_failed=True
1278 )
def check_health(self):
    """Check the health of the connection with a PING/PONG"""
    # Health checks are disabled when no interval is configured.
    if not self.health_check_interval:
        return

    check_is_due = time.monotonic() > self.next_health_check
    if not check_is_due:
        return

    self.retry.call_with_retry(
        self._send_ping,
        self._ping_failed,
        with_failure_count=True,
    )
def send_packed_command(self, command, check_health=True):
    """Send an already packed command to the Redis server.

    :param command: a single item or an iterable of items, each of which is
        handed to ``socket.sendall``. A bare ``str`` is treated as one item.
    :param check_health: run ``check_health()`` before sending; callers
        pass ``False`` to avoid recursing from within the health check.
    :raises TimeoutError: the socket write timed out.
    :raises ConnectionError: the socket write failed with an ``OSError``.

    The connection is disconnected before any exception leaves this method.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # guard against health check recursion
    if check_health:
        self.check_health()
    try:
        if isinstance(command, str):
            command = [command]
        for item in command:
            self._sock.sendall(item)
    except socket.timeout:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket")
    except OSError as e:
        self.disconnect()
        # OSError may carry (errno, message), just (message,), or no
        # arguments at all; the last case must not blow up with an
        # IndexError that hides the real socket failure.
        if len(e.args) >= 2:
            errno, errmsg = e.args[0], e.args[1]
        elif e.args:
            errno, errmsg = "UNKNOWN", e.args[0]
        else:
            errno, errmsg = "UNKNOWN", type(e).__name__
        raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
    except BaseException:
        # BaseExceptions can be raised when a socket send operation is not
        # finished, e.g. due to a timeout. Ideally, a caller could then re-try
        # to send un-sent data. However, the send_packed_command() API
        # does not support it so there is no point in keeping the connection open.
        self.disconnect()
        raise
def send_command(self, *args, **kwargs):
    """Pack and send a command to the Redis server"""
    # Only ``check_health`` is read from the keyword arguments.
    run_health_check = kwargs.get("check_health", True)
    packed = self._command_packer.pack(*args)
    self.send_packed_command(packed, check_health=run_health_check)
def can_read(self, timeout=0):
    """Poll the socket to see if there's data that can be read."""
    if not self._sock:
        self.connect()

    # Resolve the endpoint description up front so it is available for the
    # error message after the connection has been torn down.
    endpoint = self._host_error()

    try:
        readable = self._parser.can_read(timeout)
    except OSError as exc:
        self.disconnect()
        raise ConnectionError(f"Error while reading from {endpoint}: {exc.args}")
    return readable
def read_response(
    self,
    disable_decoding=False,
    *,
    timeout: Union[float, object] = SENTINEL,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the response from a previously sent command.

    ``push_request`` is only forwarded to the parser when the connection's
    protocol is 3. With ``disconnect_on_error`` set, any exception raised
    while reading disconnects the connection before it propagates. A
    ``ResponseError`` returned by the parser is raised rather than returned.
    """

    endpoint = self._host_error()

    def close_if_requested():
        if disconnect_on_error:
            self.disconnect()

    try:
        parser_kwargs = {"disable_decoding": disable_decoding, "timeout": timeout}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        response = self._parser.read_response(**parser_kwargs)
    except socket.timeout:
        close_if_requested()
        raise TimeoutError(f"Timeout reading from {endpoint}")
    except OSError as exc:
        close_if_requested()
        raise ConnectionError(f"Error while reading from {endpoint} : {exc.args}")
    except BaseException:
        # Also by default close in case of BaseException. A lot of code
        # relies on this behaviour when doing Command/Response pairs.
        # See #1128.
        close_if_requested()
        raise

    # A successful read pushes the next health check out by one interval.
    if self.health_check_interval:
        self.next_health_check = time.monotonic() + self.health_check_interval

    if isinstance(response, ResponseError):
        try:
            raise response
        finally:
            del response  # avoid creating ref cycles
    return response
def pack_command(self, *args):
    """Pack a series of arguments into the Redis protocol"""
    packer = self._command_packer
    return packer.pack(*args)
def pack_commands(self, commands):
    """Pack multiple commands into the Redis protocol.

    Small chunks are accumulated and joined into combined buffers, while
    chunks longer than ``self._buffer_cutoff`` and ``memoryview`` chunks
    are emitted on their own, in order.
    """
    packed = []
    pending = []
    pending_size = 0
    cutoff = self._buffer_cutoff

    for command in commands:
        for chunk in self._command_packer.pack(*command):
            size = len(chunk)
            standalone = size > cutoff or isinstance(chunk, memoryview)

            # Flush what has been gathered so far when the buffer has grown
            # past the cutoff or a standalone chunk is about to be emitted.
            if pending_size > cutoff or standalone:
                if pending:
                    packed.append(SYM_EMPTY.join(pending))
                pending_size = 0
                pending = []

            if standalone:
                packed.append(chunk)
                continue

            pending.append(chunk)
            pending_size += size

    if pending:
        packed.append(SYM_EMPTY.join(pending))
    return packed
def get_protocol(self) -> Union[int, str]:
    """Return the protocol version configured on this connection."""
    configured_protocol = self.protocol
    return configured_protocol
@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """Metadata stored from the server's handshake response."""
    stored_metadata = self._handshake_metadata
    return stored_metadata


@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    # Stored as-is; keys may be bytes or str.
    setattr(self, "_handshake_metadata", value)
def set_re_auth_token(self, token: TokenInterface):
    """Store ``token`` so that the next ``re_auth()`` call uses it."""
    setattr(self, "_re_auth_token", token)
def re_auth(self):
    """Re-authenticate with the pending token, if one has been set.

    Sends ``AUTH`` with the token's ``oid`` claim and its value, reads the
    reply and then clears the pending token. Does nothing without a token.
    """
    token = self._re_auth_token
    if token is None:
        return

    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None
1449 def _get_socket(self) -> Optional[socket.socket]:
1450 return self._sock
@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied to the socket once it is connected (``None`` allowed)."""
    configured_timeout = self._socket_timeout
    return configured_timeout


@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    setattr(self, "_socket_timeout", value)
@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied to the socket while connecting (``None`` allowed)."""
    configured_timeout = self._socket_connect_timeout
    return configured_timeout


@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    setattr(self, "_socket_connect_timeout", value)
def extract_connection_details(self) -> str:
    """Describe the connection for diagnostics.

    Returns ``"not connected"`` without a socket; otherwise a string with
    the resolved IP and the local socket port. The port is reported as
    ``None`` when the socket's local address cannot be obtained.
    """
    if self._sock is None:
        return "not connected"

    local_endpoint = None
    try:
        local_endpoint = self._sock.getsockname() if self._sock else None
        # Element 1 of the socket name is reported as the local port.
        local_endpoint = local_endpoint[1] if local_endpoint else None
    except (AttributeError, OSError):
        pass

    return f"connected to ip {self.get_resolved_ip()}, local socket port: {local_endpoint}"
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server"""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """Store the TCP settings, then delegate the rest to the base class.

        ``port`` is coerced with ``int()``; a falsy
        ``socket_keepalive_options`` is replaced by an empty dict.
        """
        if not socket_keepalive_options:
            socket_keepalive_options = {}

        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return ``(name, value)`` pairs describing this connection."""
        described = {"host": self.host, "port": self.port, "db": self.db}
        if self.client_name:
            described["client_name"] = self.client_name
        return list(described.items())

    def _connect(self):
        """Create a TCP socket connection.

        Every address returned by ``socket.getaddrinfo`` is tried in order
        and the first socket that connects is returned. If none connects,
        the last ``OSError`` is re-raised.
        """
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        last_error = None
        candidates = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )

        for family, socktype, proto, _canonname, address in candidates:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, value in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, option, value)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Return ``host:port`` for use in error messages."""
        return "{}:{}".format(self.host, self.port)

    @property
    def host(self) -> str:
        """Host name or address this connection targets."""
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """
    Connection wrapper that adds client-side caching on top of ``conn``.

    Cacheable commands are answered from ``cache`` when a valid entry
    exists; every other operation is delegated to the wrapped connection.
    Key tracking is switched on via a connect callback, and invalidation
    push messages remove the affected cache entries.
    """

    # Placeholder stored while a response is still in flight.
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        :param conn: the connection all network traffic is delegated to.
        :param cache: the cache holding command responses.
        :param pool_lock: lock held while draining pending responses from
            another connection that owns a cache entry.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.db = self._conn.db
        self._event_dispatcher = self._conn._event_dispatcher
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command sent last; None when that command is
        # not cacheable.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._oss_cluster_maint_notifications_handler,
                self._conn._get_parser(),
                event_dispatcher=self._conn.event_dispatcher,
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    @property
    def is_connected(self) -> bool:
        return self._conn.is_connected

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_cluster_handler_for_connection(
                oss_cluster_maint_notifications_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """
        Connect the wrapped connection and validate the server.

        :raises ConnectionError: the handshake metadata lacks the server
            name or version, or the server fails the name/version check.
        """
        self._conn.connect()

        # Handshake metadata may be keyed by bytes or by str.
        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args, **kwargs):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args, **kwargs)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command; packed commands bypass the cache."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Forward check_health so a caller's request to skip the health
        # check is honoured by the wrapped connection.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """
        Send a command unless a cached response is already available.

        Non-cacheable commands go straight to the wrapped connection.
        Cacheable commands require a ``keys`` keyword argument; when a
        cache entry exists nothing is sent and ``read_response`` serves the
        cached value.

        :raises ValueError: a cacheable command was sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                # Re-check: if the entry was invalidated during the drain,
                # fall through to send the command over the network.
                if self._cache.get(self._current_command_cache_key):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """
        Return the response for the last command, from cache if possible.

        A finished cache entry for the current command is returned as a
        deep copy without touching the network. Otherwise the response is
        read from the wrapped connection and, if the in-progress entry
        still exists, stored in the cache. ``None`` responses are never
        cached.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                if (
                    self._cache.get(self._current_command_cache_key) is not None
                    and self._cache.get(self._current_command_cache_key).status
                    != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(
                        self._cache.get(self._current_command_cache_key).cache_value
                    )
                    self._current_command_cache_key = None
                    record_csc_request(
                        result=CSCResult.HIT,
                    )
                    record_csc_network_saved(
                        bytes_saved=len(res) if hasattr(res, "__len__") else 0,
                    )
                    return res
                record_csc_request(
                    result=CSCResult.MISS,
                )

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            timeout=timeout,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    @property
    def _maint_notifications_connection_handler(
        self,
    ) -> Optional[MaintNotificationsConnectionHandler]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._maint_notifications_connection_handler
        # The wrapped connection does not support maintenance notifications.
        return None

    @_maint_notifications_connection_handler.setter
    def _maint_notifications_connection_handler(
        self, value: Optional[MaintNotificationsConnectionHandler]
    ):
        self._conn._maint_notifications_connection_handler = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        # Hand back whatever the wrapped connection produced so callers
        # receive the socket instead of None.
        return self._conn._connect()

    def _host_error(self):
        # Return the wrapped connection's description; it is interpolated
        # into error messages.
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable key tracking and hook up invalidations."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain everything currently readable on the wrapped connection."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Apply an invalidation message; ``data[1]`` holds the keys or None."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                keys_deleted = self._cache.delete_by_redis_keys(data[1])

                if len(keys_deleted) > 0:
                    record_csc_eviction(
                        count=len(keys_deleted),
                        reason=CSCReason.INVALIDATION,
                    )

    def extract_connection_details(self) -> str:
        return self._conn.extract_connection_details()
1954class SSLConnection(Connection):
1955 """Manages SSL connections to and from the Redis server(s).
1956 This class extends the Connection class, adding SSL functionality, and making
1957 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
1958 """ # noqa
1960 def __init__(
1961 self,
1962 ssl_keyfile=None,
1963 ssl_certfile=None,
1964 ssl_cert_reqs="required",
1965 ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1966 ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
1967 ssl_ca_certs=None,
1968 ssl_ca_data=None,
1969 ssl_check_hostname=True,
1970 ssl_ca_path=None,
1971 ssl_password=None,
1972 ssl_validate_ocsp=False,
1973 ssl_validate_ocsp_stapled=False,
1974 ssl_ocsp_context=None,
1975 ssl_ocsp_expected_cert=None,
1976 ssl_min_version=None,
1977 ssl_ciphers=None,
1978 **kwargs,
1979 ):
1980 """Constructor
1982 Args:
1983 ssl_keyfile: Path to an ssl private key. Defaults to None.
1984 ssl_certfile: Path to an ssl certificate. Defaults to None.
1985 ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1986 or an ssl.VerifyMode. Defaults to "required".
1987 ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1988 ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
1989 ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
1990 ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
1991 ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
1992 ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
1993 ssl_password: Password for unlocking an encrypted private key. Defaults to None.
1995 ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
1996 ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
1997 ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
1998 ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
1999 ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
2000 ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.
2002 Raises:
2003 RedisError
2004 """ # noqa
2005 if not SSL_AVAILABLE:
2006 raise RedisError("Python wasn't built with SSL support")
2008 self.keyfile = ssl_keyfile
2009 self.certfile = ssl_certfile
2010 if ssl_cert_reqs is None:
2011 ssl_cert_reqs = ssl.CERT_NONE
2012 elif isinstance(ssl_cert_reqs, str):
2013 CERT_REQS = { # noqa: N806
2014 "none": ssl.CERT_NONE,
2015 "optional": ssl.CERT_OPTIONAL,
2016 "required": ssl.CERT_REQUIRED,
2017 }
2018 if ssl_cert_reqs not in CERT_REQS:
2019 raise RedisError(
2020 f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
2021 )
2022 ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
2023 self.cert_reqs = ssl_cert_reqs
2024 self.ssl_include_verify_flags = ssl_include_verify_flags
2025 self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
2026 self.ca_certs = ssl_ca_certs
2027 self.ca_data = ssl_ca_data
2028 self.ca_path = ssl_ca_path
2029 self.check_hostname = (
2030 ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
2031 )
2032 self.certificate_password = ssl_password
2033 self.ssl_validate_ocsp = ssl_validate_ocsp
2034 self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
2035 self.ssl_ocsp_context = ssl_ocsp_context
2036 self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
2037 self.ssl_min_version = ssl_min_version
2038 self.ssl_ciphers = ssl_ciphers
2039 super().__init__(**kwargs)
2041 def _connect(self):
2042 """
2043 Wrap the socket with SSL support, handling potential errors.
2044 """
2045 sock = super()._connect()
2046 try:
2047 return self._wrap_socket_with_ssl(sock)
2048 except (OSError, RedisError):
2049 sock.close()
2050 raise
2052 def _wrap_socket_with_ssl(self, sock):
2053 """
2054 Wraps the socket with SSL support.
2056 Args:
2057 sock: The plain socket to wrap with SSL.
2059 Returns:
2060 An SSL wrapped socket.
2061 """
2062 context = ssl.create_default_context()
2063 context.check_hostname = self.check_hostname
2064 context.verify_mode = self.cert_reqs
2065 if self.ssl_include_verify_flags:
2066 for flag in self.ssl_include_verify_flags:
2067 context.verify_flags |= flag
2068 if self.ssl_exclude_verify_flags:
2069 for flag in self.ssl_exclude_verify_flags:
2070 context.verify_flags &= ~flag
2071 if self.certfile or self.keyfile:
2072 context.load_cert_chain(
2073 certfile=self.certfile,
2074 keyfile=self.keyfile,
2075 password=self.certificate_password,
2076 )
2077 if (
2078 self.ca_certs is not None
2079 or self.ca_path is not None
2080 or self.ca_data is not None
2081 ):
2082 context.load_verify_locations(
2083 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
2084 )
2085 if self.ssl_min_version is not None:
2086 context.minimum_version = self.ssl_min_version
2087 if self.ssl_ciphers:
2088 context.set_ciphers(self.ssl_ciphers)
2089 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
2090 raise RedisError("cryptography is not installed.")
2092 if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
2093 raise RedisError(
2094 "Either an OCSP staple or pure OCSP connection must be validated "
2095 "- not both."
2096 )
2098 sslsock = context.wrap_socket(sock, server_hostname=self.host)
2100 # validation for the stapled case
2101 if self.ssl_validate_ocsp_stapled:
2102 import OpenSSL
2104 from .ocsp import ocsp_staple_verifier
2106 # if a context is provided use it - otherwise, a basic context
2107 if self.ssl_ocsp_context is None:
2108 staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
2109 staple_ctx.use_certificate_file(self.certfile)
2110 staple_ctx.use_privatekey_file(self.keyfile)
2111 else:
2112 staple_ctx = self.ssl_ocsp_context
2114 staple_ctx.set_ocsp_client_callback(
2115 ocsp_staple_verifier, self.ssl_ocsp_expected_cert
2116 )
2118 # need another socket
2119 con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
2120 con.request_ocsp()
2121 con.connect((self.host, self.port))
2122 con.do_handshake()
2123 con.shutdown()
2124 return sslsock
2126 # pure ocsp validation
2127 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
2128 from .ocsp import OCSPVerifier
2130 o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
2131 if o.is_valid():
2132 return sslsock
2133 else:
2134 raise ConnectionError("ocsp validation error")
2135 return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server"""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """Initialise the parent with ``kwargs``, then record the socket
        ``path`` and the ``socket_timeout`` applied after connecting."""
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the (name, value) pairs describing this connection."""
        described = [("path", self.path), ("db", self.db)]
        if not self.client_name:
            return described
        described.append(("client_name", self.client_name))
        return described

    def _connect(self):
        """Create a Unix domain socket connection"""
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        else:
            # Connected: switch from the connect timeout to the regular one.
            uds.settimeout(self.socket_timeout)
            return uds

    def _host_error(self):
        """Return the socket path, used to identify the peer in errors."""
        return self.path
# Upper-case spellings that to_bool() interprets as False.
FALSE_STRINGS = (
    "0",
    "F",
    "FALSE",
    "N",
    "NO",
)
def to_bool(value):
    """Convert a URL option value to a boolean.

    Returns ``None`` for ``None`` or the empty string, ``False`` for strings
    that match ``FALSE_STRINGS`` case-insensitively, and ``bool(value)`` for
    everything else.
    """
    if value is None or value == "":
        return None
    spelled_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if spelled_false else bool(value)
def parse_ssl_verify_flags(value):
    """Parse a textual list of ``ssl.VerifyFlags`` names.

    Flags are passed in as a string representation of a list, e.g.
    ``[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]``; the surrounding
    brackets are optional.

    Args:
        value: Comma separated flag names, optionally wrapped in brackets.

    Returns:
        A list with the matching ``VerifyFlags`` members, in input order.

    Raises:
        ValueError: If a name is not a member of ``VerifyFlags``.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")

    # Look names up among the enum members only. A plain hasattr() check
    # would also accept unrelated attributes such as "__class__" or "mro"
    # and return objects that are not verify flags.
    # NOTE(review): the getattr default keeps the ValueError behaviour in
    # case VerifyFlags is a placeholder without members -- confirm how the
    # module defines it when ssl is unavailable.
    known_flags = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for name in verify_flags_str.split(","):
        name = name.strip()
        if name not in known_flags:
            raise ValueError(f"Invalid ssl verify flag: {name}")
        verify_flags.append(known_flags[name])
    return verify_flags
# Converters applied by parse_url() to querystring options, keyed by option
# name. Options without an entry are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    # numeric options
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    # boolean switches
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    # SSL related options
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
def parse_url(url):
    """Translate a Redis connection URL into connection keyword arguments.

    Supported schemes are ``redis://``, ``rediss://`` and ``unix://``.
    Querystring options are converted with ``URL_QUERY_ARGUMENT_PARSERS``
    when a converter exists and kept as strings otherwise; only the first
    value of a repeated option is used.

    Raises:
        ValueError: If the scheme is unsupported or a querystring value
            cannot be converted.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    options = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not convert:
            options[name] = raw
            continue
        try:
            options[name] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        options["username"] = unquote(parsed.username)
    if parsed.password:
        options["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            options["path"] = unquote(parsed.path)
        options["connection_class"] = UnixDomainSocketConnection
        return options

    # implied: parsed.scheme in ("redis", "rediss")
    if parsed.hostname:
        options["host"] = unquote(parsed.hostname)
    if parsed.port:
        options["port"] = int(parsed.port)

    # If there's a path argument, use it as the db argument if a
    # querystring value wasn't specified
    if parsed.path and "db" not in options:
        try:
            options["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        options["connection_class"] = SSLConnection

    return options
2271_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract contract for connection pools.

    Every method is abstract; the concrete semantics are defined by the
    implementing pool classes.
    """

    @abstractmethod
    def get_protocol(self):
        """Return the protocol setting of the pool."""

    @abstractmethod
    def reset(self):
        """Reset the pool."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Return a connection; passing arguments is deprecated."""

    @abstractmethod
    def get_encoder(self):
        """Return the pool's encoder."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Release ``connection``."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect connections of the pool."""

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Callback invoked with a token for re-authentication."""

    @abstractmethod
    def get_connection_count(self) -> list[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.

    Concrete pools must provide ``connection_kwargs``, ``_get_pool_lock``,
    ``_get_free_connections`` and ``_get_in_use_connections``.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        """Set up the maintenance notification handlers for the pool.

        When no config is given and the ``protocol`` kwarg satisfies the
        version-3 check, a default ``MaintNotificationsConfig`` is created.

        Raises:
            RedisError: If notifications are enabled but the protocol check
                for version 3 fails.
        """
        # Initialize maintenance notifications
        is_protocol_supported = check_protocol_version(
            kwargs.get("protocol", DEFAULT_RESP_VERSION), 3
        )

        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if not (maint_notifications_config and maint_notifications_config.enabled):
            # Notifications are off: no handlers at all.
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not is_protocol_supported:
            raise RedisError(
                "Maintenance notifications handlers on connection are only supported with RESP version 3"
            )

        self._event_dispatcher = kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # NOTE(review): this handler is replaced in both branches below, so
        # the construction looks redundant; kept to preserve behaviour in
        # case the constructor has side effects -- confirm and remove.
        self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
            self, maint_notifications_config
        )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
            self._update_connection_kwargs_for_maint_notifications(
                oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
            )
            self._maint_notifications_pool_handler = None
        else:
            self._oss_cluster_maint_notifications_handler = None
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                maint_notifications_pool_handler=self._maint_notifications_pool_handler
            )

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the cluster
            handler if one is set, otherwise in the pool handler.
            If neither handler is set, the maintenance notifications are not enabled.
        """
        if self._oss_cluster_maint_notifications_handler:
            maint_notifications_config = (
                self._oss_cluster_maint_notifications_handler.config
            )
        elif self._maint_notifications_pool_handler:
            maint_notifications_config = self._maint_notifications_pool_handler.config
        else:
            maint_notifications_config = None

        # Always answer with a real bool (a missing config used to leak None).
        return bool(maint_notifications_config and maint_notifications_config.enabled)

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Raises:
            ValueError: If notifications are currently enabled and the new
                config disables them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            # first update pool settings
            if not self._maint_notifications_pool_handler:
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )
            else:
                self._maint_notifications_pool_handler.config = (
                    maint_notifications_config
                )

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.

        Does nothing when maintenance notifications are not enabled.
        """
        if not self.maint_notifications_enabled():
            return
        if maint_notifications_pool_handler:
            self.connection_kwargs.update(
                {
                    "maint_notifications_pool_handler": maint_notifications_pool_handler,
                    "maint_notifications_config": maint_notifications_pool_handler.config,
                }
            )
        if oss_cluster_maint_notifications_handler:
            self.connection_kwargs.update(
                {
                    "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
                    "maint_notifications_config": oss_cluster_maint_notifications_handler.config,
                }
            )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """Update the maintenance notifications config for all connections in the pool.

        Free connections are disconnected after the update; in-use
        connections are marked for reconnect. The cluster handler takes
        precedence when both handlers are given.

        Raises:
            ValueError: If a connection is visited while neither handler is set.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if oss_cluster_maint_notifications_handler:
                    # set cluster handler for conn
                    conn.set_maint_notifications_cluster_handler_for_connection(
                        oss_cluster_maint_notifications_handler
                    )
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                if oss_cluster_maint_notifications_handler:
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                    conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy ``matching_address`` / ``matching_notification_hash`` matches
        every connection, as does an unrecognised ``matching_pattern``.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def _iter_pool_connections(self, include_free_connections: bool):
        """Yield the in-use connections, then optionally the free ones."""
        yield from self._get_in_use_connections()
        if include_free_connections:
            yield from self._get_free_connections()

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
                                              to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._iter_pool_connections(include_free_connections):
                if not self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    continue
                self.update_connection_settings(
                    conn,
                    state=state,
                    maintenance_notification_hash=maintenance_notification_hash,
                    host_address=host_address,
                    relaxed_timeout=relaxed_timeout,
                    update_notification_hash=update_notification_hash,
                    reset_host_address=reset_host_address,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
2756class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
2757 """
    Create a connection pool. If ``max_connections`` is set, then this
2759 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
2760 limit is reached.
2762 By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
2764 unix sockets.
2765 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
2767 If ``maint_notifications_config`` is provided, the connection pool will support
2768 maintenance notifications.
2769 Maintenance notifications are supported only with RESP3.
2770 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2771 the maintenance notifications will be enabled by default.
2773 Any additional keyword arguments are passed to the constructor of
2774 ``connection_class``.
2775 """
2777 @classmethod
2778 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
2779 """
2780 Return a connection pool configured from the given URL.
2782 For example::
2784 redis://[[username]:[password]]@localhost:6379/0
2785 rediss://[[username]:[password]]@localhost:6379/0
2786 unix://[username@]/path/to/socket.sock?db=0[&password=password]
2788 Three URL schemes are supported:
2790 - `redis://` creates a TCP socket connection. See more at:
2791 <https://www.iana.org/assignments/uri-schemes/prov/redis>
2792 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
2793 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
2794 - ``unix://``: creates a Unix Domain Socket connection.
2796 The username, password, hostname, path and all querystring values
2797 are passed through urllib.parse.unquote in order to replace any
2798 percent-encoded values with their corresponding characters.
2800 There are several ways to specify a database number. The first value
2801 found will be used:
2803 1. A ``db`` querystring option, e.g. redis://localhost?db=0
2804 2. If using the redis:// or rediss:// schemes, the path argument
2805 of the url, e.g. redis://localhost/0
2806 3. A ``db`` keyword argument to this function.
2808 If none of these options are specified, the default db=0 is used.
2810 All querystring options are cast to their appropriate Python types.
2811 Boolean arguments can be specified with string values "True"/"False"
2812 or "Yes"/"No". Values that cannot be properly cast cause a
2813 ``ValueError`` to be raised. Once parsed, the querystring arguments
2814 and keyword arguments are passed to the ``ConnectionPool``'s
2815 class initializer. In the case of conflicting arguments, querystring
2816 arguments always win.
2817 """
2818 url_options = parse_url(url)
2820 if "connection_class" in kwargs:
2821 url_options["connection_class"] = kwargs["connection_class"]
2823 kwargs.update(url_options)
2824 return cls(**kwargs)
2826 def __init__(
2827 self,
2828 connection_class=Connection,
2829 max_connections: Optional[int] = None,
2830 cache_factory: Optional[CacheFactoryInterface] = None,
2831 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
2832 **connection_kwargs,
2833 ):
2834 max_connections = max_connections or 2**31
2835 if not isinstance(max_connections, int) or max_connections < 0:
2836 raise ValueError('"max_connections" must be a positive integer')
2838 self.connection_class = connection_class
2839 self._connection_kwargs = connection_kwargs
2840 self.max_connections = max_connections
2841 self.cache = None
2842 self._cache_factory = cache_factory
2844 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
2845 if self._event_dispatcher is None:
2846 self._event_dispatcher = EventDispatcher()
2848 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
2849 if not check_protocol_version(
2850 self._connection_kwargs.get("protocol", DEFAULT_RESP_VERSION), 3
2851 ):
2852 raise RedisError("Client caching is only supported with RESP version 3")
2854 cache = self._connection_kwargs.get("cache")
2856 if cache is not None:
2857 if not isinstance(cache, CacheInterface):
2858 raise ValueError("Cache must implement CacheInterface")
2860 self.cache = cache
2861 else:
2862 if self._cache_factory is not None:
2863 self.cache = CacheProxy(self._cache_factory.get_cache())
2864 else:
2865 self.cache = CacheFactory(
2866 self._connection_kwargs.get("cache_config")
2867 ).get_cache()
2869 init_csc_items()
2870 register_csc_items_callback(
2871 callback=lambda: self.cache.size,
2872 pool_name=get_pool_name(self),
2873 )
2875 connection_kwargs.pop("cache", None)
2876 connection_kwargs.pop("cache_config", None)
2878 # a lock to protect the critical section in _checkpid().
2879 # this lock is acquired when the process id changes, such as
2880 # after a fork. during this time, multiple threads in the child
2881 # process could attempt to acquire this lock. the first thread
2882 # to acquire the lock will reset the data structures and lock
2883 # object of this pool. subsequent threads acquiring this lock
2884 # will notice the first thread already did the work and simply
2885 # release the lock.
2887 self._fork_lock = threading.RLock()
2888 self._lock = threading.RLock()
2890 # Generate unique pool ID for observability (matches go-redis behavior)
2891 import secrets
2893 self._pool_id = secrets.token_hex(4)
2895 MaintNotificationsAbstractConnectionPool.__init__(
2896 self,
2897 maint_notifications_config=maint_notifications_config,
2898 **connection_kwargs,
2899 )
2901 self.reset()
2903 # Keys that should be redacted in __repr__ to avoid exposing sensitive information
2904 SENSITIVE_REPR_KEYS = frozenset(
2905 {
2906 "password",
2907 "username",
2908 "ssl_password",
2909 "credential_provider",
2910 }
2911 )
2913 def __repr__(self) -> str:
2914 conn_kwargs = ",".join(
2915 [
2916 f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
2917 for k, v in self.connection_kwargs.items()
2918 ]
2919 )
2920 return (
2921 f"<{self.__class__.__module__}.{self.__class__.__name__}"
2922 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
2923 f"({conn_kwargs})>)>"
2924 )
2926 @property
2927 def connection_kwargs(self) -> Dict[str, Any]:
2928 return self._connection_kwargs
2930 @connection_kwargs.setter
2931 def connection_kwargs(self, value: Dict[str, Any]):
2932 self._connection_kwargs = value
2934 def get_protocol(self):
2935 """
2936 Returns:
2937 The RESP protocol version, or ``None`` if the protocol is not specified,
2938 in which case the server default will be used.
2939 """
2940 return self.connection_kwargs.get("protocol", None)
def reset(self) -> None:
    """
    Forget every tracked connection and bind the pool to the current process.

    When the tracking collections already exist, the connection-count
    metrics are first decremented by the number of idle and in-use
    connections that are about to be dropped. On the initial call the
    collections do not exist yet, so nothing is recorded.
    """
    tracking_exists = hasattr(self, "_available_connections") and hasattr(
        self, "_in_use_connections"
    )
    if tracking_exists:
        with self._lock:
            idle_count = len(self._available_connections)
            in_use_count = len(self._in_use_connections)
            if idle_count or in_use_count:
                pool_name = get_pool_name(self)
                dropped = (
                    (ConnectionState.IDLE, idle_count),
                    (ConnectionState.USED, in_use_count),
                )
                for state, count in dropped:
                    if count:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=state,
                            counter=-count,
                        )

    self._created_connections = 0
    self._available_connections = []
    self._in_use_connections = set()

    # Updating the pid must stay the very last step. reset() runs while
    # _fork_lock is held, but other threads call _checkpid(), which compares
    # self.pid with os.getpid() without taking any lock (for performance
    # reasons). As long as the pid is written last, those threads still see
    # the mismatch and block on _fork_lock until this thread is done; once
    # they get the lock they notice the reset already happened, release the
    # lock immediately and carry on.
    self.pid = os.getpid()
def __del__(self) -> None:
    """Decrement connection-count metrics for the connections this pool
    still tracks when the pool object is finalized."""
    try:
        tracking_exists = hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        )
        if not tracking_exists:
            # The object never got as far as creating its collections.
            return

        idle_count = len(self._available_connections)
        in_use_count = len(self._in_use_connections)
        if idle_count or in_use_count:
            pool_name = get_pool_name(self)
            dropped = (
                (ConnectionState.IDLE, idle_count),
                (ConnectionState.USED, in_use_count),
            )
            for state, count in dropped:
                if count:
                    record_connection_count(
                        pool_name=pool_name,
                        connection_state=state,
                        counter=-count,
                    )
    except Exception:
        # A finalizer must not raise; metric bookkeeping is best effort.
        pass
def _checkpid(self) -> None:
    """
    Re-initialise the pool when it is used from a different process.

    Raises:
        ChildDeadlockedError: ``_fork_lock`` could not be acquired within
            five seconds.
    """
    # _checkpid() is how the pool tries to stay fork-safe. Every method that
    # manipulates pool state (get_connection(), release(), ...) calls it.
    #
    # The pid stored on the instance is compared with the current pid. Equal
    # pids mean nothing happened and we return immediately.
    #
    # Different pids are taken to mean that the process forked and this is
    # the child. The child cannot use the parent's file descriptors (e.g.
    # sockets), so reset() is called to reinitialise the pool, which makes
    # the child create all new connection objects.
    #
    # self._fork_lock guards the reset so that several threads of the child
    # do not each call reset().
    #
    # There is an extremely small chance of failure:
    #   1. process A calls _checkpid() for the first time and takes
    #      self._fork_lock.
    #   2. while the lock is held, process A forks (possibly from another
    #      of its threads).
    #   3. process B, the child, inherits the pool state including the
    #      *locked* _fork_lock. B is never told when A releases it and so
    #      can never acquire it.
    #
    # To mitigate that deadlock we wait at most 5 seconds for _fork_lock and
    # raise redis.ChildDeadlockedError when the wait times out.
    if self.pid == os.getpid():
        return

    if not self._fork_lock.acquire(timeout=5):
        raise ChildDeadlockedError
    try:
        # Another thread may have finished the reset while we were waiting.
        if self.pid != os.getpid():
            self.reset()
    finally:
        self._fork_lock.release()
@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
def get_connection(self, command_name=None, *keys, **options) -> "Connection":
    """
    Hand out a connected, ready-to-use connection.

    An idle connection is reused when one exists; otherwise a new one is
    made via ``make_connection()``. If connecting or validating fails, the
    connection is passed to ``release()`` and the error is re-raised.
    """
    # Re-initialise the pool first if we are running in a forked child.
    self._checkpid()
    created_at = None

    with self._lock:
        try:
            connection = self._available_connections.pop()
        except IndexError:
            # Nothing idle: time the creation for observability.
            created_at = time.monotonic()
            connection = self.make_connection()
        self._in_use_connections.add(connection)

    # State transition IDLE -> USED. make_connection() already counted a new
    # connection as IDLE, so the counters stay balanced even if connect()
    # fails below and release() is called.
    pool_name = get_pool_name(self)
    for state, delta in ((ConnectionState.IDLE, -1), (ConnectionState.USED, 1)):
        record_connection_count(
            pool_name=pool_name,
            connection_state=state,
            counter=delta,
        )

    try:
        # Make sure the connection is actually connected to Redis.
        connection.connect()
        # A connection handed out by the pool must be ready to send a
        # command. Pending data means it was returned before everything was
        # read, or the socket was closed; either way reconnect and re-check.
        try:
            has_unread_data = (
                connection.can_read()
                and self.cache is None
                and not self.maint_notifications_enabled()
            )
            if has_unread_data:
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # Give the connection back so that it is not leaked.
        self.release(connection)
        raise

    if created_at is not None:
        record_connection_create_time(
            connection_pool=self,
            duration_seconds=time.monotonic() - created_at,
        )

    return connection
def get_encoder(self) -> Encoder:
    """Build an ``Encoder`` from the encoding-related connection kwargs.

    Missing entries fall back to ``"utf-8"``, ``"strict"`` and ``False``.
    """
    settings = self.connection_kwargs
    encoding = settings.get("encoding", "utf-8")
    encoding_errors = settings.get("encoding_errors", "strict")
    decode_responses = settings.get("decode_responses", False)
    return Encoder(
        encoding=encoding,
        encoding_errors=encoding_errors,
        decode_responses=decode_responses,
    )
def make_connection(self) -> "ConnectionInterface":
    """
    Create a new connection object for this pool.

    The connection is built from a copy of ``connection_kwargs`` and is
    wrapped in ``CacheProxyConnection`` when the pool has a cache.

    Returns:
        The newly constructed (not yet connected) connection.

    Raises:
        MaxConnectionsError: ``max_connections`` connections already exist.
    """
    if self._created_connections >= self.max_connections:
        raise MaxConnectionsError("Too many connections")

    # Reserve the slot, but give it back if construction fails. Without the
    # rollback every failed constructor call would permanently consume a
    # slot and the real error would eventually be masked by
    # "Too many connections".
    self._created_connections += 1
    try:
        kwargs = dict(self.connection_kwargs)
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**kwargs), self.cache, self._lock
            )
        else:
            connection = self.connection_class(**kwargs)
    except BaseException:
        self._created_connections -= 1
        raise

    # Record the new connection (it starts out IDLE) only after it was
    # constructed successfully.
    record_connection_count(
        pool_name=get_pool_name(self),
        connection_state=ConnectionState.IDLE,
        counter=1,
    )

    return connection
def release(self, connection: "Connection") -> None:
    """
    Return ``connection`` to the pool.

    Connections the pool is not currently tracking as in use are ignored.
    A tracked connection the pool does not own is disconnected and not
    added back.
    """
    self._checkpid()
    with self._lock:
        try:
            self._in_use_connections.remove(connection)
        except KeyError:
            # Fail gracefully when handed a connection that this pool is
            # not tracking as in use.
            return

        if not self.owns_connection(connection):
            # Not ours: do not add it back. The created-connections count
            # stays untouched because the pool did not create it, but USED
            # still has to go down since get_connection() counted it.
            connection.disconnect()
            record_connection_count(
                pool_name="unknown_pool",
                connection_state=ConnectionState.USED,
                counter=-1,
            )
            return

        if connection.should_reconnect():
            connection.disconnect()
        self._available_connections.append(connection)
        self._event_dispatcher.dispatch(AfterConnectionReleasedEvent(connection))

        # State transition USED -> IDLE.
        pool_name = get_pool_name(self)
        for state, delta in (
            (ConnectionState.USED, -1),
            (ConnectionState.IDLE, 1),
        ):
            record_connection_count(
                pool_name=pool_name,
                connection_state=state,
                counter=delta,
            )
def owns_connection(self, connection: "Connection") -> bool:
    """Return ``True`` when ``connection.pid`` equals the pool's ``pid``.

    The comparison yields a ``bool``; the previous ``int`` annotation did
    not match what is returned.
    """
    return connection.pid == self.pid
def disconnect(self, inuse_connections: bool = True) -> None:
    """
    Disconnects connections in the pool

    Idle connections are always disconnected. With ``inuse_connections``
    set to True (the default) the connections that are currently handed
    out, potentially to other threads, are disconnected as well.
    """
    self._checkpid()
    with self._lock:
        groups = [self._available_connections]
        if inuse_connections:
            groups.append(self._in_use_connections)

        for connection in chain.from_iterable(groups):
            connection.disconnect()
def close(self) -> None:
    """Close the pool by delegating to ``disconnect()`` with its defaults."""
    # No arguments on purpose: the default of disconnect() applies.
    self.disconnect()
def set_retry(self, retry: Retry) -> None:
    """Store ``retry`` in the connection kwargs and assign it to every idle
    and in-use connection the pool currently tracks."""
    self.connection_kwargs["retry"] = retry
    for conn in chain(self._available_connections, self._in_use_connections):
        conn.retry = retry
def re_auth_callback(self, token: TokenInterface):
    """
    Apply a new auth token to the connections of this pool.

    Idle connections get an ``AUTH`` command sent (and its response read)
    right away through their retry object; in-use connections are handed
    the token via ``set_re_auth_token()``.
    """

    def ignore_failure(error):
        # Error callback required by call_with_retry; delegates to _mock.
        return self._mock(error)

    with self._lock:
        for conn in self._available_connections:

            def send_auth():
                return conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                )

            def read_reply():
                return conn.read_response()

            # Both callables run before the loop advances, so closing over
            # the loop variable is safe here.
            conn.retry.call_with_retry(send_auth, ignore_failure)
            conn.retry.call_with_retry(read_reply, ignore_failure)

        for conn in self._in_use_connections:
            conn.set_re_auth_token(token)
def _get_pool_lock(self):
    """Return the lock object stored on ``self._lock``."""
    pool_lock = self._lock
    return pool_lock
def _get_free_connections(self):
    """Return a list snapshot of the idle connections, taken under the lock."""
    with self._lock:
        snapshot = [*self._available_connections]
    return snapshot
def _get_in_use_connections(self):
    """Return a set snapshot of the in-use connections, taken under the lock."""
    with self._lock:
        snapshot = {*self._in_use_connections}
    return snapshot
def _mock(self, error: RedisError):
    """
    No-op error handler, passed as the error callback to the retry object.

    :param error: the error reported by the retry object; it is ignored.
    :return: ``None``
    """
    return None
def get_connection_count(self) -> List[tuple[int, dict]]:
    """
    Report how many connections are idle and how many are in use.

    Returns:
        Two ``(count, attributes)`` pairs, idle first and in-use second.
        Each attribute mapping holds the base attributes, the pool name and
        the matching connection state.
    """
    from redis.observability.attributes import get_pool_name

    base_attributes = AttributeBuilder.build_base_attributes()
    base_attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)

    def with_state(state):
        # Each entry gets its own copy so the two states do not overwrite
        # one another.
        tagged = base_attributes.copy()
        tagged[DB_CLIENT_CONNECTION_STATE] = state.value
        return tagged

    idle_attributes = with_state(ConnectionState.IDLE)
    used_attributes = with_state(ConnectionState.USED)

    idle_total = len(self._get_free_connections())
    used_total = len(self._get_in_use_connections())
    return [
        (idle_total, idle_attributes),
        (used_total, used_attributes),
    ]
3294class BlockingConnectionPool(ConnectionPool):
3295 """
3296 Thread-safe blocking connection pool::
3298 >>> from redis.client import Redis
3299 >>> client = Redis(connection_pool=BlockingConnectionPool())
3301 It performs the same function as the default
3302 :py:class:`~redis.ConnectionPool` implementation, in that,
3303 it maintains a pool of reusable connections that can be shared by
3304 multiple redis clients (safely across threads if required).
3306 The difference is that, in the event that a client tries to get a
3307 connection from the pool when all of connections are in use, rather than
3308 raising a :py:class:`~redis.ConnectionError` (as the default
3309 :py:class:`~redis.ConnectionPool` implementation does), it
3310 makes the client wait ("blocks") for a specified number of seconds until
3311 a connection becomes available.
3313 Use ``max_connections`` to increase / decrease the pool size::
3315 >>> pool = BlockingConnectionPool(max_connections=10)
3317 Use ``timeout`` to tell it either how many seconds to wait for a connection
3318 to become available, or to block forever:
3320 >>> # Block forever.
3321 >>> pool = BlockingConnectionPool(timeout=None)
3323 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
3324 >>> # not available.
3325 >>> pool = BlockingConnectionPool(timeout=5)
3326 """
def __init__(
    self,
    max_connections=50,
    timeout=20,
    connection_class=Connection,
    queue_class=LifoQueue,
    **connection_kwargs,
):
    """
    Store the blocking-specific settings, then run the parent initialiser.

    ``queue_class`` is the queue type instantiated by ``reset()``;
    ``timeout`` is how long ``get_connection()`` waits on that queue.
    """
    # Everything reset() reads has to exist before the parent initialiser
    # runs (presumably it calls reset() - confirm against the base class).
    self._locked = False
    self._in_maintenance = False
    self.timeout = timeout
    self.queue_class = queue_class
    super().__init__(
        max_connections=max_connections,
        connection_class=connection_class,
        **connection_kwargs,
    )
def reset(self):
    """
    Rebuild the queue of ``None`` placeholders and forget all connections.

    If connections were already tracked, connection-count metrics are first
    decremented by the number of idle and in-use connections being dropped.
    While the pool is flagged as in maintenance, ``self._lock`` is held for
    the duration of the rebuild.
    """
    try:
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True

        # Metrics for the connections that are about to be dropped. The
        # queue's internal container is read directly.
        has_tracked = (
            hasattr(self, "_connections")
            and self._connections
            and hasattr(self, "pool")
        )
        if has_tracked:
            # NOTE(review): the original note here called self._lock
            # non-reentrant, yet it is re-entered below while maintenance
            # mode may already hold it. That only works with a reentrant
            # lock - confirm the lock type where it is created.
            with self._lock:
                idle_count = len(set(filter(None, self.pool.queue)))
                in_use_count = len(self._connections) - idle_count
                if idle_count > 0 or in_use_count > 0:
                    pool_name = get_pool_name(self)
                    dropped = (
                        (ConnectionState.IDLE, idle_count),
                        (ConnectionState.USED, in_use_count),
                    )
                    for state, count in dropped:
                        if count > 0:
                            record_connection_count(
                                pool_name=pool_name,
                                connection_state=state,
                                counter=-count,
                            )

        # Create a thread safe queue and fill it to capacity with ``None``.
        self.pool = self.queue_class(self.max_connections)
        try:
            while True:
                self.pool.put_nowait(None)
        except Full:
            pass

        # All real connection instances are tracked here so that they can
        # be disconnected later.
        self._connections = []
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

    # Updating the pid must stay the very last step. reset() runs while
    # _fork_lock is held, but other threads call _checkpid(), which compares
    # self.pid with os.getpid() without taking any lock (for performance
    # reasons). As long as the pid is written last, those threads still see
    # the mismatch and block on _fork_lock until this thread is done; once
    # they get the lock they notice the reset already happened, release the
    # lock immediately and carry on.
    self.pid = os.getpid()
def __del__(self) -> None:
    """Decrement connection-count metrics for the connections this pool
    still tracks when the pool object is finalized."""
    try:
        # The queue's container is read directly and no lock is taken, so a
        # finalizer running while the lock is held cannot block on it.
        has_tracked = (
            hasattr(self, "_connections")
            and self._connections
            and hasattr(self, "pool")
        )
        if has_tracked:
            idle_count = len(set(filter(None, self.pool.queue)))
            in_use_count = len(self._connections) - idle_count
            if idle_count > 0 or in_use_count > 0:
                pool_name = get_pool_name(self)
                dropped = (
                    (ConnectionState.IDLE, idle_count),
                    (ConnectionState.USED, in_use_count),
                )
                for state, count in dropped:
                    if count > 0:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=state,
                            counter=-count,
                        )
    except Exception:
        # A finalizer must not raise; metric bookkeeping is best effort.
        pass
def make_connection(self):
    """
    Construct a connection, remember it in ``_connections`` and return it.

    The connection is wrapped in ``CacheProxyConnection`` when the pool has
    a cache. While the pool is flagged as in maintenance, ``self._lock`` is
    held during construction.
    """
    try:
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True

        if self.cache is None:
            connection = self.connection_class(**self.connection_kwargs)
        else:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs),
                self.cache,
                self._lock,
            )
        self._connections.append(connection)

        # A freshly created connection starts out IDLE.
        record_connection_count(
            pool_name=get_pool_name(self),
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

        return connection
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False
@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection, blocking for ``self.timeout`` until a connection
    is available from the pool.

    If the connection returned is ``None`` then creates a new connection.
    Because we use a last-in first-out queue, the existing connections
    (having been returned to the pool after the initial ``None`` values
    were added) will be returned before ``None`` values. This means we only
    create new connections when we need to, i.e.: the actual number of
    connections will only increase in response to demand.

    If creating the new connection fails, the ``None`` placeholder is put
    back into the queue so that the pool does not lose capacity.

    Raises:
        ConnectionError: no connection became available within
            ``self.timeout``, or the connection could not be made ready.
    """
    start_time_acquired = time.monotonic()
    # Make sure we haven't changed process.
    self._checkpid()
    is_created = False

    # Try and get a connection from the pool. If one isn't available within
    # self.timeout then raise a ``ConnectionError``.
    connection = None
    try:
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you never want
            # this to be raised, construct the pool with ``timeout=None``
            # so that it blocks until a connection is available.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            # Start timing for observability
            start_time_created = time.monotonic()
            try:
                connection = self.make_connection()
            except BaseException:
                # The placeholder was already taken out of the queue. Hand
                # it back, otherwise every failed creation would shrink the
                # pool until no slot is left.
                try:
                    self.pool.put_nowait(None)
                except Full:
                    # The queue was replaced or refilled in the meantime;
                    # there is no slot to restore.
                    pass
                raise
            is_created = True
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

    # Record state transition: IDLE -> USED
    # (make_connection already recorded IDLE +1 for new connections)
    # This ensures counters stay balanced if connect() fails and release() is called
    pool_name = get_pool_name(self)
    record_connection_count(
        pool_name=pool_name,
        connection_state=ConnectionState.IDLE,
        counter=-1,
    )
    record_connection_count(
        pool_name=pool_name,
        connection_state=ConnectionState.USED,
        counter=1,
    )

    try:
        # ensure this connection is connected to Redis
        connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if connection.can_read():
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # release the connection back to the pool so that we don't leak it
        self.release(connection)
        raise

    if is_created:
        record_connection_create_time(
            connection_pool=self,
            duration_seconds=time.monotonic() - start_time_created,
        )

    record_connection_wait_time(
        pool_name=pool_name,
        duration_seconds=time.monotonic() - start_time_acquired,
    )

    return connection
def release(self, connection):
    """
    Put ``connection`` back into the queue.

    A connection the pool does not own is disconnected and a ``None``
    placeholder is queued in its place. While the pool is flagged as in
    maintenance, ``self._lock`` is held for the duration of the call.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    try:
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True

        owned = self.owns_connection(connection)
        if not owned:
            # Not ours: do not queue it. Queue a None placeholder instead,
            # which makes the pool create a replacement when one is needed.
            connection.disconnect()
            self.pool.put_nowait(None)
            # USED still has to go down since get_connection() counted it.
            record_connection_count(
                pool_name="unknown_pool",
                connection_state=ConnectionState.USED,
                counter=-1,
            )
            return

        if connection.should_reconnect():
            connection.disconnect()

        pool_name = get_pool_name(self)
        try:
            self.pool.put_nowait(connection)

            # State transition USED -> IDLE.
            for state, delta in (
                (ConnectionState.USED, -1),
                (ConnectionState.IDLE, 1),
            ):
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=state,
                    counter=delta,
                )
        except Full:
            # The queue has no free slot; the connection is not put back
            # and no state transition is recorded.
            pass
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False
def disconnect(self, inuse_connections: bool = True):
    """
    Disconnects either all connections in the pool or just the free connections.

    With ``inuse_connections`` set to True every connection in
    ``_connections`` is disconnected; otherwise only those returned by
    ``_get_free_connections()``.
    """
    self._checkpid()
    try:
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True

        targets = (
            self._connections if inuse_connections else self._get_free_connections()
        )
        for connection in targets:
            connection.disconnect()
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False
def _get_free_connections(self):
    """Return the set of truthy entries (real connections, not ``None``
    placeholders) currently sitting in the queue, read under the lock."""
    with self._lock:
        free = set(filter(None, self.pool.queue))
    return free
def _get_in_use_connections(self):
    """Return the set of created connections that are not in the queue."""
    with self._lock:
        # Whatever sits in the queue (ignoring None placeholders) is free.
        idle = set(filter(None, self.pool.queue))
        # _connections holds every connection ever created, so the ones
        # missing from the queue are the ones currently handed out.
        in_use = set(self._connections).difference(idle)
    return in_use
def set_in_maintenance(self, in_maintenance: bool):
    """
    Sets a flag that this Blocking ConnectionPool is in maintenance mode.

    While the flag is set, the pool's other methods acquire ``self._lock``
    before touching the queue or creating connections. Per the original
    documentation, the flag is only meant to be set while a MOVING
    notification is being processed.
    """
    setattr(self, "_in_maintenance", in_maintenance)