Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32 CacheProxy,
33)
35from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
36from ._parsers.socket import SENTINEL
37from .auth.token import TokenInterface
38from .backoff import NoBackoff
39from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
40from .driver_info import DriverInfo, resolve_driver_info
41from .event import AfterConnectionReleasedEvent, EventDispatcher
42from .exceptions import (
43 AuthenticationError,
44 AuthenticationWrongNumberOfArgsError,
45 ChildDeadlockedError,
46 ConnectionError,
47 DataError,
48 MaxConnectionsError,
49 RedisError,
50 ResponseError,
51 TimeoutError,
52)
53from .maint_notifications import (
54 MaintenanceState,
55 MaintNotificationsConfig,
56 MaintNotificationsConnectionHandler,
57 MaintNotificationsPoolHandler,
58 OSSMaintNotificationsHandler,
59)
60from .observability.attributes import (
61 DB_CLIENT_CONNECTION_POOL_NAME,
62 DB_CLIENT_CONNECTION_STATE,
63 AttributeBuilder,
64 ConnectionState,
65 CSCReason,
66 CSCResult,
67 get_pool_name,
68)
69from .observability.metrics import CloseReason
70from .observability.recorder import (
71 init_csc_items,
72 record_connection_closed,
73 record_connection_count,
74 record_connection_create_time,
75 record_connection_wait_time,
76 record_csc_eviction,
77 record_csc_network_saved,
78 record_csc_request,
79 record_error_count,
80 register_csc_items_callback,
81)
82from .retry import Retry
83from .utils import (
84 CRYPTOGRAPHY_AVAILABLE,
85 HIREDIS_AVAILABLE,
86 SSL_AVAILABLE,
87 check_protocol_version,
88 compare_versions,
89 deprecated_args,
90 ensure_string,
91 format_error_message,
92 str_if_bytes,
93)
# Optional TLS support: ``ssl`` is imported only when the runtime provides it
# (SSL_AVAILABLE comes from .utils). The ``None`` fallbacks make attribute
# access fail loudly only when TLS is actually requested.
if SSL_AVAILABLE:
    import ssl
    from ssl import VerifyFlags
else:
    ssl = None
    VerifyFlags = None

# hiredis is an optional C accelerator for RESP parsing/packing.
if HIREDIS_AVAILABLE:
    import hiredis

# RESP protocol framing bytes.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when ``protocol`` cannot be interpreted as an integer.
DEFAULT_RESP_VERSION = 2

# Prefer the hiredis-backed parser when the extension is installed.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser
class HiredisRespSerializer:
    """Command packer backed by the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        # The first argument may contain several space-separated words
        # (e.g. 'CONFIG GET'); the server expects them as separate args.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError as err:
            # Surface unsupported argument types as DataError while
            # preserving the original traceback.
            raise DataError(err).with_traceback(err.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Pure-Python fallback packer for the Redis wire protocol."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # buffer_cutoff: byte size above which values are chunked separately.
        # encode: callable turning each argument into bytes.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # The client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        pieces = []
        cutoff = self._buffer_cutoff
        # Start the frame with the array header: *<argc>\r\n
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for item in map(self.encode, args):
            size = len(item)
            # To avoid large string mallocs, emit large values and
            # memoryviews as separate chunks in the output list.
            if len(pending) > cutoff or size > cutoff or isinstance(item, memoryview):
                pending = SYM_EMPTY.join(
                    (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF)
                )
                pieces.extend((pending, item))
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join(
                    (
                        pending,
                        SYM_DOLLAR,
                        str(size).encode(),
                        SYM_CRLF,
                        item,
                        SYM_CRLF,
                    )
                )
        pieces.append(pending)
        return pieces
# NOTE(review): this class is not derived from abc.ABC, so @abstractmethod
# is documentation-only here and does not block instantiation.
class ConnectionInterface:
    """Interface describing the contract of a Redis connection object."""

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used to build the connection repr."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callback invoked after (re)connecting."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a new protocol parser of the given class."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection to the server."""
        pass

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake (auth, select db, etc.)."""
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Close the connection."""
        pass

    @abstractmethod
    def check_health(self):
        """Verify the connection is still healthy, pinging if needed."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already-packed command to the server."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return True if data is available to read within ``timeout``."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Pack a single command into wire-format chunks."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Pack multiple commands (e.g. a pipeline) into wire format."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the HELLO handshake."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to re-authenticate with on the next opportunity."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token."""
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass

    @abstractmethod
    def extract_connection_details(self) -> str:
        """Return a human-readable summary of the connection target."""
        pass

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """
        Return ``True`` if the connection to the server is active.
        """
        pass
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
            event_dispatcher (Optional[EventDispatcher]): Dispatcher for connection
                events; a fresh one is created when not supplied.
        """
        self.maint_notifications_config = maint_notifications_config
        # Assignment goes through the ``maintenance_state`` property setter below.
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        if event_dispatcher is not None:
            self.event_dispatcher = event_dispatcher
        else:
            self.event_dispatcher = EventDispatcher()

        # Wires handlers into the parser and captures the original
        # host/timeout values (needed to restore after maintenance windows).
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # IDs of start/end maintenance notifications seen on this connection,
        # used for de-duplication bookkeeping.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Return the protocol parser attached to this connection."""
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the underlying socket, or None if not connected."""
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        """Host address the connection targets."""
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Read/write timeout for the established socket."""
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Timeout used while establishing the socket connection."""
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Close the connection."""
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        Raises:
            RedisError: If notifications are enabled but no parser was
                provided, or the parser type does not support push messages.
        """
        # When notifications are disabled, null out all handler slots so that
        # later attribute access is safe, and skip the rest of the setup.
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        # Only the hiredis and RESP3 parsers support push notifications.
        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        # Per-connection handler is always created when notifications are on.
        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so they can be restored by
        # reset_tmp_settings() after a maintenance window ends.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach (a per-connection copy of) a pool handler to this connection."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler already wired to the parser; only refresh its config.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """Attach an OSS cluster maintenance handler to this connection."""
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler already wired to the parser; only refresh its config.
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Perform the CLIENT MAINT_NOTIFICATIONS handshake with the server.

        Raises:
            ValueError: If the connection object has no ``host`` attribute.
            ResponseError: If the server rejects the handshake and the config
                mode is strict (``enabled`` is not ``"auto"``).
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # In "auto" mode a server-side rejection is tolerated:
                # log at debug level but don't fail the connection.
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state of the connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        """Record the id of a processed 'maintenance start' notification."""
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        """Return ids of processed 'maintenance start' notifications."""
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        """Record the id of a skipped 'maintenance end' notification."""
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        """Return ids of skipped 'maintenance end' notifications."""
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        """Clear all recorded notification ids."""
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a (possibly relaxed) timeout to the live socket and parser.

        A ``relaxed_timeout`` of -1 means "no relaxation": fall back to the
        connection's configured ``socket_timeout``.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # if the current timeout is 0 it means we are in the middle of a can_read call
            # in this case we don't want to change the timeout because the operation
            # is non-blocking and should return immediately
            # Changing the state from non-blocking to blocking in the middle of a read operation
            # will lead to a deadlock
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate a socket timeout into the parser's buffered reader."""
        parser = self._get_parser()
        if parser and parser._buffer:
            # The two parser implementations store the timeout differently.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override host and/or timeouts during maintenance.

        The value of SENTINEL is used to indicate that the property should not be updated.
        For ``tmp_relaxed_timeout``, -1 plays the same "skip" role.
        """
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the original host/timeouts captured at configuration time."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
772class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
773 "Manages communication to and from a Redis server"
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Initialize a new Connection.

        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If both username/password and a credential_provider are given.
        ConnectionError
            If ``protocol`` is not 2 or 3 (or not convertible to an int).
        """
        # Credentials may come either from username/password or from a
        # provider object, never both.
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # PID is recorded so pools can detect connections inherited via fork.
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self._socket_timeout = socket_timeout
        # Connect timeout defaults to the read timeout when not given.
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self._socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                # Errors listed but no policy given: retry once, no backoff.
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            # No retries configured at all.
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # Normalize the protocol argument: non-int-convertible types fall back
        # to the default; non-numeric strings are rejected.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
            self.protocol = p
        if self.protocol == 3 and parser_class == _RESP2Parser:
            # If the protocol is 3 but the parser is RESP2, change it to RESP3
            # This is needed because the parser might be set before the protocol
            # or might be provided as a kwarg to the constructor
            # We need to react on discrepancy only for RESP2 and RESP3
            # as hiredis supports both
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self._command_packer = self._construct_command_packer(command_packer)
        self._should_reconnect = False

        # Set up maintenance notifications
        MaintNotificationsAbstractConnection.__init__(
            self,
            maint_notifications_config,
            maint_notifications_pool_handler,
            maintenance_state,
            maintenance_notification_hash,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            self._parser,
            event_dispatcher=self._event_dispatcher,
        )
929 def __repr__(self):
930 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
931 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used by ``__repr__``."""
        pass
    def __del__(self):
        # Best-effort cleanup on garbage collection; destructor errors are
        # deliberately swallowed because raising here would be unhandleable.
        try:
            self.disconnect()
        except Exception:
            pass
943 @property
944 def is_connected(self) -> bool:
945 return self._sock is not None
947 def _construct_command_packer(self, packer):
948 if packer is not None:
949 return packer
950 elif HIREDIS_AVAILABLE:
951 return HiredisRespSerializer()
952 else:
953 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
955 def register_connect_callback(self, callback):
956 """
957 Register a callback to be called when the connection is established either
958 initially or reconnected. This allows listeners to issue commands that
959 are ephemeral to the connection, for example pub/sub subscription or
960 key tracking. The callback must be a _method_ and will be kept as
961 a weak reference.
962 """
963 wm = weakref.WeakMethod(callback)
964 if wm not in self._connect_callbacks:
965 self._connect_callbacks.append(wm)
967 def deregister_connect_callback(self, callback):
968 """
969 De-register a previously registered callback. It will no-longer receive
970 notifications on connection events. Calling this is not required when the
971 listener goes away, since the callbacks are kept as weak methods.
972 """
973 try:
974 self._connect_callbacks.remove(weakref.WeakMethod(callback))
975 except ValueError:
976 pass
978 def set_parser(self, parser_class):
979 """
980 Creates a new instance of parser_class with socket size:
981 _socket_read_size and assigns it to the parser for the connection
982 :param parser_class: The required parser class
983 """
984 self._parser = parser_class(socket_read_size=self._socket_read_size)
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        """Return the protocol parser currently attached to this connection."""
        return self._parser
989 def connect(self):
990 "Connects to the Redis server if not already connected"
991 # try once the socket connect with the handshake, retry the whole
992 # connect/handshake flow based on retry policy
993 self.retry.call_with_retry(
994 lambda: self.connect_check_health(
995 check_health=True, retry_socket_connect=False
996 ),
997 lambda error: self.disconnect(error),
998 )
    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Establish the socket connection and run the connection handshake.

        No-op when a socket already exists.  When ``retry_socket_connect``
        is True, only the raw socket connect is retried per the retry
        policy; the handshake itself is not retried here.  ``check_health``
        is forwarded to the default handshake.
        """
        if self._sock:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = [0]

        def failure_callback(error, failure_count):
            # Record how many attempts have failed and drop the broken
            # socket before the retry policy decides on another attempt.
            actual_retry_attempts[0] = failure_count
            self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    self._connect,
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                sock = self._connect()
        except socket.timeout:
            e = TimeoutError("Timeout connecting to server")
            record_error_count(
                server_address=self.host,
                server_port=self.port,
                network_peer_address=self.host,
                network_peer_port=self.port,
                error_type=e,
                retry_attempts=actual_retry_attempts[0],
            )
            raise e
        except OSError as e:
            e = ConnectionError(self._error_message(e))
            # getattr: some connection types may not expose host/port
            record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts[0],
            )
            raise e

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)
    @abstractmethod
    def _connect(self):
        # Subclasses create and return a connected socket object.
        pass
    @abstractmethod
    def _host_error(self):
        # Subclasses return a short server-address string for error messages.
        pass
    def _error_message(self, exception):
        # Delegate message formatting to the shared helper, tagging the
        # exception with this connection's host information.
        return format_error_message(self._host_error(), exception)
    def on_connect(self):
        # Run the full handshake, including the health check.
        self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        # keep a handle on the original parser; it may be swapped below
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # single credential means password only; HELLO AUTH needs an
                # explicit username, so use the default user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # metadata keys may be bytes or str depending on decoding;
            # accept either form
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info
        try:
            if self.driver_info and self.driver_info.formatted_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.driver_info.formatted_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # servers without CLIENT SETINFO support; best effort only
            pass

        try:
            if self.driver_info and self.driver_info.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.driver_info.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # best effort, same as LIB-NAME above
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
    def disconnect(self, *args, **kwargs):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self.reset_should_reconnect()

        if conn_sock is None:
            return

        # Only shut the socket down from the process that created it; a
        # forked child must not disturb the parent's descriptor.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

        # Optional metrics context supplied by callers:
        #   error: exception that triggered this disconnect, if any
        #   failure_count: number of failed retry attempts so far
        #   health_check_failed: True when a PING health check caused this
        error = kwargs.get("error")
        failure_count = kwargs.get("failure_count")
        health_check_failed = kwargs.get("health_check_failed")

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # record the error only once the retry budget is exceeded
            if failure_count is not None and failure_count > self.retry.get_retries():
                record_error_count(
                    server_address=self.host,
                    server_port=self.port,
                    network_peer_address=self.host,
                    network_peer_port=self.port,
                    error_type=error,
                    retry_attempts=failure_count,
                )

            record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

        if self.maintenance_state == MaintenanceState.MAINTENANCE:
            # this block will be executed only if the connection was in maintenance state
            # and the connection was closed.
            # The state change won't be applied on connections that are in Moving state
            # because their state and configurations will be handled when the moving ttl expires.
            self.reset_tmp_settings(reset_relaxed_timeout=True)
            self.maintenance_state = MaintenanceState.NONE
            # reset the sets that keep track of received start maint
            # notifications and skipped end maint notifications
            self.reset_received_notifications()
1259 def mark_for_reconnect(self):
1260 self._should_reconnect = True
1262 def should_reconnect(self):
1263 return self._should_reconnect
1265 def reset_should_reconnect(self):
1266 self._should_reconnect = False
    def _send_ping(self):
        """Send PING, expect PONG in return"""
        # check_health=False: this *is* the health check; avoid recursion
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")
1274 def _ping_failed(self, error, failure_count):
1275 """Function to call when PING fails"""
1276 self.disconnect(
1277 error=error, failure_count=failure_count, health_check_failed=True
1278 )
1280 def check_health(self):
1281 """Check the health of the connection with a PING/PONG"""
1282 if self.health_check_interval and time.monotonic() > self.next_health_check:
1283 self.retry.call_with_retry(
1284 self._send_ping,
1285 self._ping_failed,
1286 with_failure_count=True,
1287 )
    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            # connect lazily; skip the health check so we don't recurse
            # through check_health -> send_command -> here
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # OSError args may be (errno, message) or a single message
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise
1320 def send_command(self, *args, **kwargs):
1321 """Pack and send a command to the Redis server"""
1322 self.send_packed_command(
1323 self._command_packer.pack(*args),
1324 check_health=kwargs.get("check_health", True),
1325 )
1327 def can_read(self, timeout=0):
1328 """Poll the socket to see if there's data that can be read."""
1329 sock = self._sock
1330 if not sock:
1331 self.connect()
1333 host_error = self._host_error()
1335 try:
1336 return self._parser.can_read(timeout)
1338 except OSError as e:
1339 self.disconnect()
1340 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command

        :param disable_decoding: skip response decoding when True
        :param timeout: read timeout; SENTINEL means use the default
        :param disconnect_on_error: drop the connection on read errors
        :param push_request: forwarded to RESP3 parsers only
        """
        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                # only RESP3 parsers accept the push_request argument
                response = self._parser.read_response(
                    disable_decoding=disable_decoding,
                    push_request=push_request,
                    timeout=timeout,
                )
            else:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, timeout=timeout
                )
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            # a successful read proves liveness; push the next check out
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response
1391 def pack_command(self, *args):
1392 """Pack a series of arguments into the Redis protocol"""
1393 return self._command_packer.pack(*args)
    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        # Coalesce small chunks into buffers of roughly _buffer_cutoff bytes
        # so the pipeline needs fewer socket writes; oversized chunks and
        # memoryviews are emitted as-is to avoid copying them.
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush the accumulated buffer first
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
1425 def get_protocol(self) -> Union[int, str]:
1426 return self.protocol
    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        # Server metadata captured from the HELLO handshake response; keys
        # may be bytes or str depending on decoding settings.
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value
    def set_re_auth_token(self, token: TokenInterface):
        # Stash a token for the next re_auth() call to authenticate with.
        self._re_auth_token = token
    def re_auth(self):
        # Re-authenticate using the stored token (AUTH <oid> <value>), then
        # clear it so each token is used at most once.
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None
1449 def _get_socket(self) -> Optional[socket.socket]:
1450 return self._sock
    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        # Timeout (seconds) applied to reads/writes on the connected socket.
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._socket_timeout = value
    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        # Timeout (seconds) applied while establishing the socket connection.
        return self._socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._socket_connect_timeout = value
1468 def extract_connection_details(self) -> str:
1469 socket_address = None
1470 if self._sock is None:
1471 return "not connected"
1472 try:
1473 socket_address = self._sock.getsockname() if self._sock else None
1474 socket_address = socket_address[1] if socket_address else None
1475 except (AttributeError, OSError):
1476 pass
1478 return f"connected to ip {self.get_resolved_ip()}, local socket port: {socket_address}"
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """TCP connection settings.

        :param host: server hostname or IP address
        :param port: server TCP port
        :param socket_keepalive: enable SO_KEEPALIVE on the socket
        :param socket_keepalive_options: mapping of TCP keepalive options
        :param socket_type: forwarded to socket.getaddrinfo (0 = any)
        :param kwargs: forwarded to AbstractConnection
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs used when building this connection's repr
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        # NOTE(review): socket_type is passed in getaddrinfo's *family*
        # position -- presumably intentional, but worth confirming.
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                # remember the last failure and try the next resolved address
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        # "host:port" string used in error messages
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Decorator around a real connection that adds client-side caching.

    Responses of cacheable read-only commands are stored in a shared cache
    and served locally; server-assisted invalidation (CLIENT TRACKING)
    keeps the cache coherent with the server.
    """

    # Placeholder stored while a cacheable command's response is in flight.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires Redis 7.4 or later (see connect()).
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        :param conn: wrapped connection that performs the actual I/O
        :param cache: shared client-side cache instance
        :param pool_lock: lock guarding the owning connection pool
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.db = self._conn.db
        self._event_dispatcher = self._conn._event_dispatcher
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting its response, if any.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        # Mirror maintenance-notification state from the wrapped connection.
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._oss_cluster_maint_notifications_handler,
                self._conn._get_parser(),
                event_dispatcher=self._conn.event_dispatcher,
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    @property
    def is_connected(self) -> bool:
        return self._conn.is_connected

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_cluster_handler_for_connection(
                oss_cluster_maint_notifications_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection, then verify from the handshake
        metadata that the server supports client-side caching."""
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args, **kwargs):
        # The cache cannot be trusted across a disconnect (invalidation
        # messages may have been missed), so flush it.
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args, **kwargs)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Packed commands bypass the cache entirely.
        # Fix: forward check_health instead of silently dropping it.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, consulting the client-side cache first for
        cacheable read-only commands."""
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Creates cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                # Re-check: if the entry was invalidated during the drain,
                # fall through to send the command over the network.
                if self._cache.get(self._current_command_cache_key):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Serve the current command's response from the cache when a
        valid entry exists; otherwise read from the wire and cache it."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                if (
                    self._cache.get(self._current_command_cache_key) is not None
                    and self._cache.get(self._current_command_cache_key).status
                    != CacheEntryStatus.IN_PROGRESS
                ):
                    # deepcopy so callers cannot mutate the cached value
                    res = copy.deepcopy(
                        self._cache.get(self._current_command_cache_key).cache_value
                    )
                    self._current_command_cache_key = None
                    record_csc_request(
                        result=CSCResult.HIT,
                    )
                    record_csc_network_saved(
                        bytes_saved=len(res) if hasattr(res, "__len__") else 0,
                    )
                    return res
                record_csc_request(
                    result=CSCResult.MISS,
                )

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            timeout=timeout,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    @property
    def _maint_notifications_connection_handler(
        self,
    ) -> Optional[MaintNotificationsConnectionHandler]:
        # Returns None when the wrapped connection does not support
        # maintenance notifications.
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._maint_notifications_connection_handler

    @_maint_notifications_connection_handler.setter
    def _maint_notifications_connection_handler(
        self, value: Optional[MaintNotificationsConnectionHandler]
    ):
        self._conn._maint_notifications_connection_handler = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance
        notifications we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # Fix: return the delegate's host string; previously the return
        # value was dropped, so error messages rendered the host as "None".
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        # Connect-time callback: enable server-assisted invalidation and
        # route invalidation push messages into the cache.
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        # Drain queued invalidation push messages before using the cache.
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                keys_deleted = self._cache.delete_by_redis_keys(data[1])

                if len(keys_deleted) > 0:
                    record_csc_eviction(
                        count=len(keys_deleted),
                        reason=CSCReason.INVALIDATION,
                    )

    def extract_connection_details(self) -> str:
        return self._conn.extract_connection_details()
1954class SSLConnection(Connection):
1955 """Manages SSL connections to and from the Redis server(s).
1956 This class extends the Connection class, adding SSL functionality, and making
1957 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
1958 """ # noqa
    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            # map the user-facing string names onto ssl.VerifyMode values
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # hostname checking is only meaningful when certificates are verified
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)
2041 def _connect(self):
2042 """
2043 Wrap the socket with SSL support, handling potential errors.
2044 """
2045 sock = super()._connect()
2046 try:
2047 return self._wrap_socket_with_ssl(sock)
2048 except (OSError, RedisError):
2049 sock.close()
2050 raise
    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: If OCSP validation is requested without the
                ``cryptography`` package, or if both stapled and pure OCSP
                validation are requested at once.
            ConnectionError: If pure OCSP validation fails.
        """
        # Build the context from the settings captured in __init__.
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        # Apply explicit include/exclude adjustments on top of the
        # default verify_flags.
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        # Validate OCSP configuration before performing the handshake.
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            # Imported lazily: pyOpenSSL is only required for staple checks.
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            # A second handshake on a fresh socket is used purely to request
            # and verify the stapled OCSP response via the callback above.
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """Create a connection object bound to the given socket *path*."""
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        details = [("path", self.path), ("db", self.db)]
        if self.client_name:
            details += [("client_name", self.client_name)]
        return details

    def _connect(self):
        """Create and connect a Unix domain socket."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Shut down and close explicitly so a failed connect does not
            # leave behind an unclosed socket (ResourceWarning).
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            sock.close()
            raise
        # Switch from the connect timeout to the steady-state read timeout.
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        """Return the identifier used in error messages (the socket path)."""
        return self.path
# String values (case-insensitive) that map to False in URL querystrings.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a querystring value to a boolean.

    Returns None for missing/empty values, False for recognised falsy
    strings (see FALSE_STRINGS) and falsy non-strings, True otherwise.
    """
    if value is None or value == "":
        return None
    is_falsy_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_falsy_string else bool(value)
def parse_ssl_verify_flags(value):
    """Parse a comma-separated (optionally bracketed) list of VerifyFlags names.

    e.g. "[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]" from a URL
    querystring.

    Raises:
        ValueError: If any name is not a VerifyFlags member.
    """
    cleaned = value.replace("[", "").replace("]", "")

    parsed = []
    for name in cleaned.split(","):
        name = name.strip()
        if not hasattr(VerifyFlags, name):
            raise ValueError(f"Invalid ssl verify flag: {name}")
        parsed.append(getattr(VerifyFlags, name))
    return parsed
# Maps known URL querystring argument names to callables that cast the raw
# (percent-decoded) string value to the expected Python type. parse_url()
# consults this table; a TypeError/ValueError from a parser is re-raised
# there as ValueError. Unknown querystring keys pass through as strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
def parse_url(url):
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into kwargs.

    Username, password, hostname, path and querystring values are
    percent-decoded. Known querystring arguments are cast to their Python
    types via URL_QUERY_ARGUMENT_PARSERS; unknown ones are kept as strings.
    For redis/rediss URLs the path component supplies the db number when no
    ``db`` querystring argument is present.

    Returns:
        A dict of keyword arguments suitable for a ConnectionPool/client,
        including ``connection_class`` for unix:// and rediss:// schemes.

    Raises:
        ValueError: If the scheme is unsupported, or a querystring value
            cannot be cast to its expected type.
    """
    # str.startswith accepts a tuple of prefixes - one call instead of three.
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # parse_qs yields a list per key; use the first value. The former
        # `value and len(value) > 0` check was redundant - truthiness of a
        # list already covers emptiness.
        if value:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
2271_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface that every connection pool implementation provides."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version, or None for the server default."""
        pass

    @abstractmethod
    def reset(self):
        """Forget all tracked connections and reinitialize pool state."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Check a connection out of the pool (positional args are deprecated)."""
        pass

    @abstractmethod
    def get_encoder(self):
        """Return an Encoder configured from the pool's connection kwargs."""
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return a previously checked-out connection to the pool."""
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections (optionally including in-use ones)."""
        pass

    @abstractmethod
    def close(self):
        """Close the pool and all of its connections."""
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy used by the pool's connections."""
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a refreshed token."""
        pass

    @abstractmethod
    def get_connection_count(self) -> list[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        pass
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        # Initialize maintenance notifications
        # Maintenance notifications require RESP3; if no explicit config was
        # given but the protocol supports it, enable the defaults.
        is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3)

        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._event_dispatcher = kwargs.get("event_dispatcher", None)
            if self._event_dispatcher is None:
                self._event_dispatcher = EventDispatcher()

            # NOTE(review): this handler instance is unconditionally replaced
            # in both branches below (set to None in the OSS-cluster branch,
            # re-created in the else branch), so this construction appears
            # redundant - confirm MaintNotificationsPoolHandler.__init__ has
            # no required side effects before removing it.
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
            if oss_cluster_maint_notifications_handler:
                # Cluster deployments route notifications through the shared
                # OSS cluster handler instead of a per-pool handler.
                self._oss_cluster_maint_notifications_handler = (
                    oss_cluster_maint_notifications_handler
                )
                self._update_connection_kwargs_for_maint_notifications(
                    oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
                )
                self._maint_notifications_pool_handler = None
            else:
                self._oss_cluster_maint_notifications_handler = None
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )

                self._update_connection_kwargs_for_maint_notifications(
                    maint_notifications_pool_handler=self._maint_notifications_pool_handler
                )
        else:
            # Notifications disabled (or no config): no handlers at all.
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Kwargs applied to every future connection created by the pool."""
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock guarding the pool's connection collections."""
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the currently idle/available connections."""
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently checked out of the pool."""
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        if self._oss_cluster_maint_notifications_handler:
            maint_notifications_config = (
                self._oss_cluster_maint_notifications_handler.config
            )
        else:
            maint_notifications_config = (
                self._maint_notifications_pool_handler.config
                if self._maint_notifications_pool_handler
                else None
            )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Raises:
            ValueError: If notifications are already enabled and the new
                config attempts to disable them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            # first update pool settings
            if not self._maint_notifications_pool_handler:
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )
            else:
                self._maint_notifications_pool_handler.config = (
                    maint_notifications_config
                )

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.
        Only affects connections created after this call.
        """
        if not self.maint_notifications_enabled():
            return
        if maint_notifications_pool_handler:
            self.connection_kwargs.update(
                {
                    "maint_notifications_pool_handler": maint_notifications_pool_handler,
                    "maint_notifications_config": maint_notifications_pool_handler.config,
                }
            )
        if oss_cluster_maint_notifications_handler:
            self.connection_kwargs.update(
                {
                    "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
                    "maint_notifications_config": oss_cluster_maint_notifications_handler.config,
                }
            )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """Update the maintenance notifications config for all connections in the pool.

        Free connections are disconnected (they reconnect with the new
        settings); in-use connections are reconfigured in place and marked
        for reconnect.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if oss_cluster_maint_notifications_handler:
                    # set cluster handler for conn
                    conn.set_maint_notifications_cluster_handler_for_connection(
                        oss_cluster_maint_notifications_handler
                    )
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                if oss_cluster_maint_notifications_handler:
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                    conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy matching value (None) matches every connection for that
        pattern; an unknown pattern also matches everything.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Only the fields whose corresponding argument/flag is provided are
        touched; the socket timeout is refreshed at the end in any case.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
2754class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
2755 """
2756 Create a connection pool. ``If max_connections`` is set, then this
2757 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
2758 limit is reached.
2760 By default, TCP connections are created unless ``connection_class``
2761 is specified. Use class:`.UnixDomainSocketConnection` for
2762 unix sockets.
2763 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
2765 If ``maint_notifications_config`` is provided, the connection pool will support
2766 maintenance notifications.
2767 Maintenance notifications are supported only with RESP3.
2768 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2769 the maintenance notifications will be enabled by default.
2771 Any additional keyword arguments are passed to the constructor of
2772 ``connection_class``.
2773 """
2775 @classmethod
2776 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
2777 """
2778 Return a connection pool configured from the given URL.
2780 For example::
2782 redis://[[username]:[password]]@localhost:6379/0
2783 rediss://[[username]:[password]]@localhost:6379/0
2784 unix://[username@]/path/to/socket.sock?db=0[&password=password]
2786 Three URL schemes are supported:
2788 - `redis://` creates a TCP socket connection. See more at:
2789 <https://www.iana.org/assignments/uri-schemes/prov/redis>
2790 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
2791 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
2792 - ``unix://``: creates a Unix Domain Socket connection.
2794 The username, password, hostname, path and all querystring values
2795 are passed through urllib.parse.unquote in order to replace any
2796 percent-encoded values with their corresponding characters.
2798 There are several ways to specify a database number. The first value
2799 found will be used:
2801 1. A ``db`` querystring option, e.g. redis://localhost?db=0
2802 2. If using the redis:// or rediss:// schemes, the path argument
2803 of the url, e.g. redis://localhost/0
2804 3. A ``db`` keyword argument to this function.
2806 If none of these options are specified, the default db=0 is used.
2808 All querystring options are cast to their appropriate Python types.
2809 Boolean arguments can be specified with string values "True"/"False"
2810 or "Yes"/"No". Values that cannot be properly cast cause a
2811 ``ValueError`` to be raised. Once parsed, the querystring arguments
2812 and keyword arguments are passed to the ``ConnectionPool``'s
2813 class initializer. In the case of conflicting arguments, querystring
2814 arguments always win.
2815 """
2816 url_options = parse_url(url)
2818 if "connection_class" in kwargs:
2819 url_options["connection_class"] = kwargs["connection_class"]
2821 kwargs.update(url_options)
2822 return cls(**kwargs)
    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """Initialize the pool.

        Args:
            connection_class: Class used to create each connection.
            max_connections: Upper bound on pool size; defaults to 2**31
                (effectively unbounded).
            cache_factory: Optional factory used to build the client-side
                cache when ``cache_config`` is supplied.
            maint_notifications_config: Optional maintenance notifications
                configuration (RESP3 only).
            **connection_kwargs: Passed through to ``connection_class``.

        Raises:
            ValueError: If max_connections is not a non-negative integer,
                or a provided cache does not implement CacheInterface.
            RedisError: If caching is requested without RESP3.
        """
        max_connections = max_connections or 2**31
        # NOTE(review): the message says "positive integer" but the check
        # admits 0 - confirm whether 0 is intentionally allowed.
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # Client-side caching setup (RESP3 only).
        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if not check_protocol_version(self._connection_kwargs.get("protocol"), 3):
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                # NOTE(review): a custom factory's cache is wrapped in
                # CacheProxy but the default factory's cache is not -
                # confirm this asymmetry is intentional.
                if self._cache_factory is not None:
                    self.cache = CacheProxy(self._cache_factory.get_cache())
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

            # Register cache-size observability callbacks.
            init_csc_items()
            register_csc_items_callback(
                callback=lambda: self.cache.size,
                pool_name=get_pool_name(self),
            )

        # Cache kwargs must not be forwarded to individual connections.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        # Generate unique pool ID for observability (matches go-redis behavior)
        import secrets

        self._pool_id = secrets.token_hex(4)

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()
    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    # (credentials would otherwise leak into logs whenever a pool/client is printed).
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )
2909 def __repr__(self) -> str:
2910 conn_kwargs = ",".join(
2911 [
2912 f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
2913 for k, v in self.connection_kwargs.items()
2914 ]
2915 )
2916 return (
2917 f"<{self.__class__.__module__}.{self.__class__.__name__}"
2918 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
2919 f"({conn_kwargs})>)>"
2920 )
    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Kwargs passed to the connection_class for every new connection."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        # Replaces the whole dict; only affects connections created afterwards.
        self._connection_kwargs = value
2930 def get_protocol(self):
2931 """
2932 Returns:
2933 The RESP protocol version, or ``None`` if the protocol is not specified,
2934 in which case the server default will be used.
2935 """
2936 return self.connection_kwargs.get("protocol", None)
    def reset(self) -> None:
        """Discard all tracked connections and re-key the pool to this process.

        Called from __init__ and after a pid change; emits negative
        connection-count metrics for any connections being dropped.
        """
        # Record metrics for connections being removed before clearing
        # (only if attributes exist - they won't during __init__)
        if hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        ):
            with self._lock:
                idle_count = len(self._available_connections)
                in_use_count = len(self._in_use_connections)
                if idle_count > 0 or in_use_count > 0:
                    pool_name = get_pool_name(self)
                    if idle_count > 0:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=ConnectionState.IDLE,
                            counter=-idle_count,
                        )
                    if in_use_count > 0:
                        record_connection_count(
                            pool_name=pool_name,
                            connection_state=ConnectionState.USED,
                            counter=-in_use_count,
                        )

        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
2977 def __del__(self) -> None:
2978 """Clean up connection pool and record metrics when garbage collected."""
2979 try:
2980 if not hasattr(self, "_available_connections") or not hasattr(
2981 self, "_in_use_connections"
2982 ):
2983 return
2984 # Record metrics for all connections being removed
2985 idle_count = len(self._available_connections)
2986 in_use_count = len(self._in_use_connections)
2987 if idle_count > 0 or in_use_count > 0:
2988 pool_name = get_pool_name(self)
2989 if idle_count > 0:
2990 record_connection_count(
2991 pool_name=pool_name,
2992 connection_state=ConnectionState.IDLE,
2993 counter=-idle_count,
2994 )
2995 if in_use_count > 0:
2996 record_connection_count(
2997 pool_name=pool_name,
2998 connection_state=ConnectionState.USED,
2999 counter=-in_use_count,
3000 )
3001 except Exception:
3002 pass
    def _checkpid(self) -> None:
        """Detect a fork and reinitialize the pool in the child process."""
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """Get a connection from the pool.

        Pops an idle connection when one exists, otherwise creates a new one
        via ``make_connection()`` (which may raise ``MaxConnectionsError``).
        The connection is verified to be connected and free of unexpected
        unread data before being returned; on any failure it is released back
        to the pool so it is not leaked.
        """
        self._checkpid()
        is_created = False

        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                # Start timing for observability
                start_time_created = time.monotonic()

                connection = self.make_connection()
                is_created = True
            self._in_use_connections.add(connection)

            # Record state transition: IDLE -> USED
            # (make_connection already recorded IDLE +1 for new connections)
            # This ensures counters stay balanced if connect() fails and
            # release() is called
            pool_name = get_pool_name(self)
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-1,
            )
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                # Pending data is tolerated when client-side caching or
                # maintenance notifications are enabled (server pushes may
                # legitimately arrive between commands).
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        return connection
3122 def get_encoder(self) -> Encoder:
3123 "Return an encoder based on encoding settings"
3124 kwargs = self.connection_kwargs
3125 return Encoder(
3126 encoding=kwargs.get("encoding", "utf-8"),
3127 encoding_errors=kwargs.get("encoding_errors", "strict"),
3128 decode_responses=kwargs.get("decode_responses", False),
3129 )
3131 def make_connection(self) -> "ConnectionInterface":
3132 "Create a new connection"
3133 if self._created_connections >= self.max_connections:
3134 raise MaxConnectionsError("Too many connections")
3135 self._created_connections += 1
3137 kwargs = dict(self.connection_kwargs)
3139 # Create the connection first, then record metrics only on success
3140 if self.cache is not None:
3141 connection = CacheProxyConnection(
3142 self.connection_class(**kwargs), self.cache, self._lock
3143 )
3144 else:
3145 connection = self.connection_class(**kwargs)
3147 # Record new connection created (starts as IDLE) - only after successful construction
3148 record_connection_count(
3149 pool_name=get_pool_name(self),
3150 connection_state=ConnectionState.IDLE,
3151 counter=1,
3152 )
3154 return connection
3156 def release(self, connection: "Connection") -> None:
3157 "Releases the connection back to the pool"
3158 self._checkpid()
3159 with self._lock:
3160 try:
3161 self._in_use_connections.remove(connection)
3162 except KeyError:
3163 # Gracefully fail when a connection is returned to this pool
3164 # that the pool doesn't actually own
3165 return
3167 if self.owns_connection(connection):
3168 if connection.should_reconnect():
3169 connection.disconnect()
3170 self._available_connections.append(connection)
3171 self._event_dispatcher.dispatch(
3172 AfterConnectionReleasedEvent(connection)
3173 )
3175 # Record state transition: USED -> IDLE
3176 pool_name = get_pool_name(self)
3177 record_connection_count(
3178 pool_name=pool_name,
3179 connection_state=ConnectionState.USED,
3180 counter=-1,
3181 )
3182 record_connection_count(
3183 pool_name=pool_name,
3184 connection_state=ConnectionState.IDLE,
3185 counter=1,
3186 )
3187 else:
3188 # Pool doesn't own this connection, do not add it back
3189 # to the pool.
3190 # The created connections count should not be changed,
3191 # because the connection was not created by the pool.
3192 # Still need to decrement USED since it was counted in get_connection()
3193 connection.disconnect()
3194 record_connection_count(
3195 pool_name="unknown_pool",
3196 connection_state=ConnectionState.USED,
3197 counter=-1,
3198 )
3199 return
3201 def owns_connection(self, connection: "Connection") -> int:
3202 return connection.pid == self.pid
3204 def disconnect(self, inuse_connections: bool = True) -> None:
3205 """
3206 Disconnects connections in the pool
3208 If ``inuse_connections`` is True, disconnect connections that are
3209 currently in use, potentially by other threads. Otherwise only disconnect
3210 connections that are idle in the pool.
3211 """
3212 self._checkpid()
3213 with self._lock:
3214 if inuse_connections:
3215 connections = chain(
3216 self._available_connections, self._in_use_connections
3217 )
3218 else:
3219 connections = self._available_connections
3221 for connection in connections:
3222 connection.disconnect()
3224 def close(self) -> None:
3225 """Close the pool, disconnecting all connections"""
3226 self.disconnect()
3228 def set_retry(self, retry: Retry) -> None:
3229 self.connection_kwargs.update({"retry": retry})
3230 for conn in self._available_connections:
3231 conn.retry = retry
3232 for conn in self._in_use_connections:
3233 conn.retry = retry
3235 def re_auth_callback(self, token: TokenInterface):
3236 with self._lock:
3237 for conn in self._available_connections:
3238 conn.retry.call_with_retry(
3239 lambda: conn.send_command(
3240 "AUTH", token.try_get("oid"), token.get_value()
3241 ),
3242 lambda error: self._mock(error),
3243 )
3244 conn.retry.call_with_retry(
3245 lambda: conn.read_response(), lambda error: self._mock(error)
3246 )
3247 for conn in self._in_use_connections:
3248 conn.set_re_auth_token(token)
3250 def _get_pool_lock(self):
3251 return self._lock
3253 def _get_free_connections(self):
3254 with self._lock:
3255 return list(self._available_connections)
3257 def _get_in_use_connections(self):
3258 with self._lock:
3259 return set(self._in_use_connections)
3261 def _mock(self, error: RedisError):
3262 """
3263 Dummy functions, needs to be passed as error callback to retry object.
3264 :param error:
3265 :return:
3266 """
3267 pass
3269 def get_connection_count(self) -> List[tuple[int, dict]]:
3270 from redis.observability.attributes import get_pool_name
3272 attributes = AttributeBuilder.build_base_attributes()
3273 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
3274 free_connections_attributes = attributes.copy()
3275 in_use_connections_attributes = attributes.copy()
3277 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3278 ConnectionState.IDLE.value
3279 )
3280 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3281 ConnectionState.USED.value
3282 )
3284 return [
3285 (len(self._get_free_connections()), free_connections_attributes),
3286 (len(self._get_in_use_connections()), in_use_connections_attributes),
3287 ]
3290class BlockingConnectionPool(ConnectionPool):
3291 """
3292 Thread-safe blocking connection pool::
3294 >>> from redis.client import Redis
3295 >>> client = Redis(connection_pool=BlockingConnectionPool())
3297 It performs the same function as the default
3298 :py:class:`~redis.ConnectionPool` implementation, in that,
3299 it maintains a pool of reusable connections that can be shared by
3300 multiple redis clients (safely across threads if required).
3302 The difference is that, in the event that a client tries to get a
3303 connection from the pool when all of connections are in use, rather than
3304 raising a :py:class:`~redis.ConnectionError` (as the default
3305 :py:class:`~redis.ConnectionPool` implementation does), it
3306 makes the client wait ("blocks") for a specified number of seconds until
3307 a connection becomes available.
3309 Use ``max_connections`` to increase / decrease the pool size::
3311 >>> pool = BlockingConnectionPool(max_connections=10)
3313 Use ``timeout`` to tell it either how many seconds to wait for a connection
3314 to become available, or to block forever:
3316 >>> # Block forever.
3317 >>> pool = BlockingConnectionPool(timeout=None)
3319 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
3320 >>> # not available.
3321 >>> pool = BlockingConnectionPool(timeout=5)
3322 """
    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        """Initialize the blocking pool.

        ``timeout`` is the number of seconds to block waiting for a
        connection (``None`` blocks forever). ``queue_class`` holds the
        idle-connection slots (LIFO by default so warm connections are
        reused first).
        """
        self.queue_class = queue_class
        self.timeout = timeout
        # Maintenance-mode bookkeeping: while _in_maintenance is True, pool
        # mutations serialize through self._lock (see set_in_maintenance()).
        self._in_maintenance = False
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        try:
            if self._in_maintenance:
                # Serialize with maintenance processing; release in finally.
                self._lock.acquire()
                self._locked = True

            # Record metrics for connections being removed before clearing
            # Note: Access pool.queue directly to avoid deadlock since we may
            # already hold self._lock (which is non-reentrant)
            if (
                hasattr(self, "_connections")
                and self._connections
                and hasattr(self, "pool")
            ):
                with self._lock:
                    connections_in_queue = {conn for conn in self.pool.queue if conn}
                    idle_count = len(connections_in_queue)
                    # Everything ever created minus what is idle is in use.
                    in_use_count = len(self._connections) - idle_count
                    if idle_count > 0 or in_use_count > 0:
                        pool_name = get_pool_name(self)
                        if idle_count > 0:
                            record_connection_count(
                                pool_name=pool_name,
                                connection_state=ConnectionState.IDLE,
                                counter=-idle_count,
                            )
                        if in_use_count > 0:
                            record_connection_count(
                                pool_name=pool_name,
                                connection_state=ConnectionState.USED,
                                counter=-in_use_count,
                            )

            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
3405 def __del__(self) -> None:
3406 """Clean up connection pool and record metrics when garbage collected."""
3407 try:
3408 # Note: Access pool.queue directly to avoid potential deadlock
3409 # if GC runs while the lock is held by the same thread
3410 if (
3411 hasattr(self, "_connections")
3412 and self._connections
3413 and hasattr(self, "pool")
3414 ):
3415 connections_in_queue = {conn for conn in self.pool.queue if conn}
3416 idle_count = len(connections_in_queue)
3417 in_use_count = len(self._connections) - idle_count
3418 if idle_count > 0 or in_use_count > 0:
3419 pool_name = get_pool_name(self)
3420 if idle_count > 0:
3421 record_connection_count(
3422 pool_name=pool_name,
3423 connection_state=ConnectionState.IDLE,
3424 counter=-idle_count,
3425 )
3426 if in_use_count > 0:
3427 record_connection_count(
3428 pool_name=pool_name,
3429 connection_state=ConnectionState.USED,
3430 counter=-in_use_count,
3431 )
3432 except Exception:
3433 pass
    def make_connection(self):
        "Make a fresh connection."
        try:
            if self._in_maintenance:
                # Serialize creation with maintenance processing; the lock is
                # released in the finally block below.
                self._lock.acquire()
                self._locked = True

            if self.cache is not None:
                # Wrap the raw connection so responses can be served from and
                # written to the client-side cache.
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            # Track every created connection so disconnect() can reach the
            # in-use ones too (the queue only holds idle connections).
            self._connections.append(connection)

            # Record new connection created (starts as IDLE)
            record_connection_count(
                pool_name=get_pool_name(self),
                connection_state=ConnectionState.IDLE,
                counter=1,
            )

            return connection
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Timestamp for the connection-wait-time metric reported at the end.
        start_time_acquired = time.monotonic()
        # Make sure we haven't changed process.
        self._checkpid()
        is_created = False

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            if self._in_maintenance:
                # Serialize with maintenance processing (MOVING notification).
                self._lock.acquire()
                self._locked = True
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                # Start timing for observability
                start_time_created = time.monotonic()

                connection = self.make_connection()
                is_created = True
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        # Record state transition: IDLE -> USED
        # (make_connection already recorded IDLE +1 for new connections)
        # This ensures counters stay balanced if connect() fails and release()
        # is called
        pool_name = get_pool_name(self)
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=-1,
        )
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=1,
        )

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        record_connection_wait_time(
            pool_name=pool_name,
            duration_seconds=time.monotonic() - start_time_acquired,
        )

        return connection
    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()

        try:
            if self._in_maintenance:
                # Serialize with maintenance processing; released in finally.
                self._lock.acquire()
                self._locked = True
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                # Still need to decrement USED since it was counted in get_connection()
                record_connection_count(
                    pool_name="unknown_pool",
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            pool_name = get_pool_name(self)
            try:
                self.pool.put_nowait(connection)

                # Record state transition: USED -> IDLE
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=1,
                )
            except Full:
                # Queue is full: the connection is silently dropped.
                # NOTE(review): USED is not decremented on this path —
                # looks like possible metric drift; confirm intent.
                pass
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False
3618 def disconnect(self, inuse_connections: bool = True):
3619 """
3620 Disconnects either all connections in the pool or just the free connections.
3621 """
3622 self._checkpid()
3623 try:
3624 if self._in_maintenance:
3625 self._lock.acquire()
3626 self._locked = True
3628 if inuse_connections:
3629 connections = self._connections
3630 else:
3631 connections = self._get_free_connections()
3633 for connection in connections:
3634 connection.disconnect()
3635 finally:
3636 if self._locked:
3637 try:
3638 self._lock.release()
3639 except Exception:
3640 pass
3641 self._locked = False
3643 def _get_free_connections(self):
3644 with self._lock:
3645 return {conn for conn in self.pool.queue if conn}
3647 def _get_in_use_connections(self):
3648 with self._lock:
3649 # free connections
3650 connections_in_queue = {conn for conn in self.pool.queue if conn}
3651 # in self._connections we keep all created connections
3652 # so the ones that are not in the queue are the in use ones
3653 return {
3654 conn for conn in self._connections if conn not in connections_in_queue
3655 }
3657 def set_in_maintenance(self, in_maintenance: bool):
3658 """
3659 Sets a flag that this Blocking ConnectionPool is in maintenance mode.
3661 This is used to prevent new connections from being created while we are in maintenance mode.
3662 The pool will be in maintenance mode only when we are processing a MOVING notification.
3663 """
3664 self._in_maintenance = in_maintenance