Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32 CacheProxy,
33)
35from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
36from .auth.token import TokenInterface
37from .backoff import NoBackoff
38from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
39from .driver_info import DriverInfo, resolve_driver_info
40from .event import AfterConnectionReleasedEvent, EventDispatcher
41from .exceptions import (
42 AuthenticationError,
43 AuthenticationWrongNumberOfArgsError,
44 ChildDeadlockedError,
45 ConnectionError,
46 DataError,
47 MaxConnectionsError,
48 RedisError,
49 ResponseError,
50 TimeoutError,
51)
52from .maint_notifications import (
53 MaintenanceState,
54 MaintNotificationsConfig,
55 MaintNotificationsConnectionHandler,
56 MaintNotificationsPoolHandler,
57 OSSMaintNotificationsHandler,
58)
59from .observability.attributes import (
60 DB_CLIENT_CONNECTION_POOL_NAME,
61 DB_CLIENT_CONNECTION_STATE,
62 AttributeBuilder,
63 ConnectionState,
64 CSCReason,
65 CSCResult,
66 get_pool_name,
67)
68from .observability.metrics import CloseReason
69from .observability.recorder import (
70 init_csc_items,
71 record_connection_closed,
72 record_connection_create_time,
73 record_connection_wait_time,
74 record_csc_eviction,
75 record_csc_network_saved,
76 record_csc_request,
77 record_error_count,
78 register_csc_items_callback,
79)
80from .retry import Retry
81from .utils import (
82 CRYPTOGRAPHY_AVAILABLE,
83 HIREDIS_AVAILABLE,
84 SSL_AVAILABLE,
85 check_protocol_version,
86 compare_versions,
87 deprecated_args,
88 ensure_string,
89 format_error_message,
90 str_if_bytes,
91)
93if SSL_AVAILABLE:
94 import ssl
95 from ssl import VerifyFlags
96else:
97 ssl = None
98 VerifyFlags = None
100if HIREDIS_AVAILABLE:
101 import hiredis
# RESP protocol framing bytes used when packing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version assumed when the caller supplies no usable protocol value.
DEFAULT_RESP_VERSION = 2

# Sentinel distinguishing "argument not supplied" from an explicit None.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the extension is installed;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser
class HiredisRespSerializer:
    """Serialize commands into RESP using the C-accelerated hiredis packer."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        first = args[0]
        # Multi-word command names (e.g. 'CONFIG GET') must be split into
        # separate arguments before packing.
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError:
            # Re-raise as the library's DataError while preserving the
            # original traceback for debugging.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return [packed]
class PythonRespSerializer:
    """Pure-Python RESP command packer, used when hiredis is unavailable."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Commands whose packed form exceeds the cutoff are chunked into
        # multiple list entries instead of one large string.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # The client might have included 1 or more literal arguments in
        # the command name, e.g. 'CONFIG GET'. The Redis server expects
        # these to arrive as separate arguments, so split the first one
        # manually. The split pieces are bytestrings, so they are not
        # re-encoded below.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        cutoff = self._buffer_cutoff
        pieces = []
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for arg in map(self.encode, args):
            arg_length = len(arg)
            # To avoid large string mallocs, emit big values and
            # memoryviews as separate list entries.
            if (
                len(buff) > cutoff
                or arg_length > cutoff
                or isinstance(arg, memoryview)
            ):
                header = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                pieces.extend((header, arg))
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )

        pieces.append(buff)
        return pieces
class ConnectionInterface:
    """Abstract interface that all connection implementations must satisfy."""

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used to build the connection repr."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callback invoked when the connection (re)connects."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Replace the response parser with an instance of ``parser_class``."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection to the server."""
        pass

    @abstractmethod
    def on_connect(self):
        """Initialize the connection after the socket is established."""
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Close the connection."""
        pass

    @abstractmethod
    def check_health(self):
        """Check that the connection is still usable."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already RESP-packed command to the server."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether data is available to read within ``timeout``."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Return the RESP-encoded form of a single command."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Return the RESP-encoded form of multiple commands."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata captured from the server during connection setup."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used for re-authentication."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token."""
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass

    @abstractmethod
    def extract_connection_details(self) -> str:
        """Return a string describing this connection (for diagnostics)."""
        pass
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        if event_dispatcher is not None:
            self.event_dispatcher = event_dispatcher
        else:
            self.event_dispatcher = EventDispatcher()

        # Wires the handlers into the parser (no-op when notifications are
        # disabled in the config).
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # Notification ids already handled / deliberately skipped, used to
        # de-duplicate start/end notifications.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Return the parser instance attached to this connection."""
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the underlying socket, or None when not connected."""
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        """
        # When notifications are disabled, clear all handler slots and bail.
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so temporary maintenance
        # overrides can later be reverted (see reset_tmp_settings()).
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of a pool handler to this connection."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """Attach an OSS-cluster maintenance handler to this connection's parser."""
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send the CLIENT MAINT_NOTIFICATIONS handshake to the server.

        Raises on failure unless the config's enabled mode is "auto", in
        which case a ResponseError is only logged.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state of this connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        """Record a processed maintenance-start notification id."""
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        """Return ids of maintenance-start notifications already processed."""
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        """Record a maintenance-end notification id that was skipped."""
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        """Return ids of maintenance-end notifications that were skipped."""
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        """Clear both processed-start and skipped-end notification id sets."""
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a relaxed timeout to the live socket (-1 restores the default)."""
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # if the current timeout is 0 it means we are in the middle of a can_read call
            # in this case we don't want to change the timeout because the operation
            # is non-blocking and should return immediately
            # Changing the state from non-blocking to blocking in the middle of a read operation
            # will lead to a deadlock
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate a timeout change to the parser's read buffer, if any."""
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        # -1 means "leave the timeouts unchanged"
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore host and/or timeouts saved when notifications were configured."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
762class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
763 "Manages communication to and from a Redis server"
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Initialize a new Connection.

        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If username/password are combined with a credential_provider.
        ConnectionError
            If ``protocol`` is not an integer or not 2/3.
        """
        # username/password and a credential provider are mutually exclusive.
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # pid is recorded so pools can detect use across os.fork().
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self._socket_timeout = socket_timeout
        # Connect timeout defaults to the read timeout when unspecified.
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self._socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            # No retry configured: a zero-attempt Retry object is used.
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # Normalize the protocol: None -> default, non-numeric -> error,
        # numeric but outside {2, 3} -> error.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        if self.protocol == 3 and parser_class == _RESP2Parser:
            # If the protocol is 3 but the parser is RESP2, change it to RESP3
            # This is needed because the parser might be set before the protocol
            # or might be provided as a kwarg to the constructor
            # We need to react on discrepancy only for RESP2 and RESP3
            # as hiredis supports both
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self._command_packer = self._construct_command_packer(command_packer)
        self._should_reconnect = False

        # Set up maintenance notifications
        MaintNotificationsAbstractConnection.__init__(
            self,
            maint_notifications_config,
            maint_notifications_pool_handler,
            maintenance_state,
            maintenance_notification_hash,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            self._parser,
            event_dispatcher=self._event_dispatcher,
        )
919 def __repr__(self):
920 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
921 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used by ``__repr__``."""
        pass
    def __del__(self):
        # Best-effort disconnect on garbage collection; never raise from
        # __del__ (the interpreter may already be shutting down).
        try:
            self.disconnect()
        except Exception:
            pass
933 def _construct_command_packer(self, packer):
934 if packer is not None:
935 return packer
936 elif HIREDIS_AVAILABLE:
937 return HiredisRespSerializer()
938 else:
939 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
941 def register_connect_callback(self, callback):
942 """
943 Register a callback to be called when the connection is established either
944 initially or reconnected. This allows listeners to issue commands that
945 are ephemeral to the connection, for example pub/sub subscription or
946 key tracking. The callback must be a _method_ and will be kept as
947 a weak reference.
948 """
949 wm = weakref.WeakMethod(callback)
950 if wm not in self._connect_callbacks:
951 self._connect_callbacks.append(wm)
953 def deregister_connect_callback(self, callback):
954 """
955 De-register a previously registered callback. It will no-longer receive
956 notifications on connection events. Calling this is not required when the
957 listener goes away, since the callbacks are kept as weak methods.
958 """
959 try:
960 self._connect_callbacks.remove(weakref.WeakMethod(callback))
961 except ValueError:
962 pass
964 def set_parser(self, parser_class):
965 """
966 Creates a new instance of parser_class with socket size:
967 _socket_read_size and assigns it to the parser for the connection
968 :param parser_class: The required parser class
969 """
970 self._parser = parser_class(socket_read_size=self._socket_read_size)
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        # Concrete accessor required by MaintNotificationsAbstractConnection.
        return self._parser
975 def connect(self):
976 "Connects to the Redis server if not already connected"
977 # try once the socket connect with the handshake, retry the whole
978 # connect/handshake flow based on retry policy
979 self.retry.call_with_retry(
980 lambda: self.connect_check_health(
981 check_health=True, retry_socket_connect=False
982 ),
983 lambda error: self.disconnect(error),
984 )
    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (once or with retries) and run on-connect setup.

        Args:
            check_health: Forwarded to the default on-connect handler.
            retry_socket_connect: When True, the raw socket connect is
                retried per ``self.retry``; when False a single attempt
                is made (the caller drives retries, see ``connect()``).
        """
        # Already connected: nothing to do.
        if self._sock:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = [0]

        def failure_callback(error, failure_count):
            actual_retry_attempts[0] = failure_count
            self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    self._connect,
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                sock = self._connect()
        except socket.timeout:
            # Record the failure for observability, then surface it as the
            # library's TimeoutError.
            e = TimeoutError("Timeout connecting to server")
            record_error_count(
                server_address=self.host,
                server_port=self.port,
                network_peer_address=self.host,
                network_peer_port=self.port,
                error_type=e,
                retry_attempts=actual_retry_attempts[0],
            )
            raise e
        except OSError as e:
            e = ConnectionError(self._error_message(e))
            record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts[0],
            )
            raise e

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)
    @abstractmethod
    def _connect(self):
        """Create and return a connected socket; implemented by subclasses."""
        pass
    @abstractmethod
    def _host_error(self):
        """Return a string identifying the peer (for error messages)."""
        pass
1060 def _error_message(self, exception):
1061 return format_error_message(self._host_error(), exception)
1063 def on_connect(self):
1064 self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection: authenticate, negotiate the RESP
        protocol version, set client name / library info, and select a
        database.

        The order matters: AUTH (or HELLO AUTH) must run before any
        health-checked command, since PING would fail pre-auth.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                # switch to a RESP3-capable parser before reading HELLO's reply
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as the "default" user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # the proto key may come back as bytes or str depending on decoding
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info
        # (best-effort: older servers reject CLIENT SETINFO)
        try:
            if self.driver_info and self.driver_info.formatted_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.driver_info.formatted_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.driver_info and self.driver_info.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.driver_info.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
    def disconnect(self, *args, **kwargs):
        """Close the connection to the Redis server and record why it closed.

        Recognized keyword arguments (all optional):
          error: the exception that triggered the disconnect, if any.
          failure_count: number of failed retry attempts so far.
          health_check_failed: True when a PING health check caused this.
        """
        self._parser.on_disconnect()

        # grab and clear the socket first so the connection reads as
        # closed before the (possibly slow) teardown below runs
        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self.reset_should_reconnect()

        if conn_sock is None:
            return

        # only shut the socket down from the process that created it
        # (self.pid); a forked child must not tear down the parent's socket
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

        error = kwargs.get("error")
        failure_count = kwargs.get("failure_count")
        health_check_failed = kwargs.get("health_check_failed")

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # only record an error metric once the retry budget is exhausted
            if failure_count is not None and failure_count > self.retry.get_retries():
                record_error_count(
                    server_address=self.host,
                    server_port=self.port,
                    network_peer_address=self.host,
                    network_peer_port=self.port,
                    error_type=error,
                    retry_attempts=failure_count,
                )

            record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

        if self.maintenance_state == MaintenanceState.MAINTENANCE:
            # this block will be executed only if the connection was in maintenance state
            # and the connection was closed.
            # The state change won't be applied on connections that are in Moving state
            # because their state and configurations will be handled when the moving ttl expires.
            self.reset_tmp_settings(reset_relaxed_timeout=True)
            self.maintenance_state = MaintenanceState.NONE
            # reset the sets that keep track of received start maint
            # notifications and skipped end maint notifications
            self.reset_received_notifications()
1245 def mark_for_reconnect(self):
1246 self._should_reconnect = True
1248 def should_reconnect(self):
1249 return self._should_reconnect
1251 def reset_should_reconnect(self):
1252 self._should_reconnect = False
1254 def _send_ping(self):
1255 """Send PING, expect PONG in return"""
1256 self.send_command("PING", check_health=False)
1257 if str_if_bytes(self.read_response()) != "PONG":
1258 raise ConnectionError("Bad response from PING health check")
1260 def _ping_failed(self, error, failure_count):
1261 """Function to call when PING fails"""
1262 self.disconnect(
1263 error=error, failure_count=failure_count, health_check_failed=True
1264 )
1266 def check_health(self):
1267 """Check the health of the connection with a PING/PONG"""
1268 if self.health_check_interval and time.monotonic() > self.next_health_check:
1269 self.retry.call_with_retry(
1270 self._send_ping,
1271 self._ping_failed,
1272 with_failure_count=True,
1273 )
1275 def send_packed_command(self, command, check_health=True):
1276 """Send an already packed command to the Redis server"""
1277 if not self._sock:
1278 self.connect_check_health(check_health=False)
1279 # guard against health check recursion
1280 if check_health:
1281 self.check_health()
1282 try:
1283 if isinstance(command, str):
1284 command = [command]
1285 for item in command:
1286 self._sock.sendall(item)
1287 except socket.timeout:
1288 self.disconnect()
1289 raise TimeoutError("Timeout writing to socket")
1290 except OSError as e:
1291 self.disconnect()
1292 if len(e.args) == 1:
1293 errno, errmsg = "UNKNOWN", e.args[0]
1294 else:
1295 errno = e.args[0]
1296 errmsg = e.args[1]
1297 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1298 except BaseException:
1299 # BaseExceptions can be raised when a socket send operation is not
1300 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1301 # to send un-sent data. However, the send_packed_command() API
1302 # does not support it so there is no point in keeping the connection open.
1303 self.disconnect()
1304 raise
1306 def send_command(self, *args, **kwargs):
1307 """Pack and send a command to the Redis server"""
1308 self.send_packed_command(
1309 self._command_packer.pack(*args),
1310 check_health=kwargs.get("check_health", True),
1311 )
1313 def can_read(self, timeout=0):
1314 """Poll the socket to see if there's data that can be read."""
1315 sock = self._sock
1316 if not sock:
1317 self.connect()
1319 host_error = self._host_error()
1321 try:
1322 return self._parser.can_read(timeout)
1324 except OSError as e:
1325 self.disconnect()
1326 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: skip response decoding in the parser.
        :param disconnect_on_error: close the connection when reading fails
            (the default; most callers rely on strict command/response pairing).
        :param push_request: forwarded to the parser only on RESP3, where
            push messages can be interleaved with replies.
        """

        host_error = self._host_error()

        try:
            # only the RESP3 parser understands push messages
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        # a successful read proves the connection is alive, so push the
        # next health-check deadline forward
        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response
1372 def pack_command(self, *args):
1373 """Pack a series of arguments into the Redis protocol"""
1374 return self._command_packer.pack(*args)
1376 def pack_commands(self, commands):
1377 """Pack multiple commands into the Redis protocol"""
1378 output = []
1379 pieces = []
1380 buffer_length = 0
1381 buffer_cutoff = self._buffer_cutoff
1383 for cmd in commands:
1384 for chunk in self._command_packer.pack(*cmd):
1385 chunklen = len(chunk)
1386 if (
1387 buffer_length > buffer_cutoff
1388 or chunklen > buffer_cutoff
1389 or isinstance(chunk, memoryview)
1390 ):
1391 if pieces:
1392 output.append(SYM_EMPTY.join(pieces))
1393 buffer_length = 0
1394 pieces = []
1396 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1397 output.append(chunk)
1398 else:
1399 pieces.append(chunk)
1400 buffer_length += chunklen
1402 if pieces:
1403 output.append(SYM_EMPTY.join(pieces))
1404 return output
1406 def get_protocol(self) -> Union[int, str]:
1407 return self.protocol
    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server in the HELLO handshake response.

        Keys may be bytes or str depending on response decoding.
        """
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to authenticate with on the next ``re_auth()`` call."""
        self._re_auth_token = token
1420 def re_auth(self):
1421 if self._re_auth_token is not None:
1422 self.send_command(
1423 "AUTH",
1424 self._re_auth_token.try_get("oid"),
1425 self._re_auth_token.get_value(),
1426 )
1427 self.read_response()
1428 self._re_auth_token = None
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the raw socket, or None when disconnected."""
        return self._sock

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Read/write timeout (seconds) applied once the socket is connected."""
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Timeout (seconds) used while establishing the connection."""
        return self._socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._socket_connect_timeout = value
1449 def extract_connection_details(self) -> str:
1450 socket_address = None
1451 if self._sock is None:
1452 return "not connected"
1453 try:
1454 socket_address = self._sock.getsockname() if self._sock else None
1455 socket_address = socket_address[1] if socket_address else None
1456 except (AttributeError, OSError):
1457 pass
1459 return f"connected to ip {self.get_resolved_ip()}, local socket port: {socket_address}"
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (label, value) pairs identifying this connection in repr()."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Mimics ``socket.create_connection()`` so both IPv4 and IPv6 work,
        but sets our socket options before calling ``connect()``.
        """
        last_error = None
        for family, socktype, proto, canonname, socket_address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY: disable Nagle's algorithm for lower latency
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP keepalive, plus any per-option tuning requested
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for opt_name, opt_value in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, opt_name, opt_value)

                # apply the connect timeout before connecting...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)

                # ...then switch to the operational timeout
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # remember the failure and try the next resolved address
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Host:port string used in error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection decorator that adds client-side caching.

    Wraps an inner ``ConnectionInterface`` and serves cachable read-only
    commands from a shared ``CacheInterface``. Server invalidation push
    messages (enabled via CLIENT TRACKING on connect) evict entries that
    were modified elsewhere.
    """

    # Placeholder stored in the cache while the real response is in flight.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires Redis 7.4 or later.
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        :param conn: the real connection being decorated.
        :param cache: shared cache backing all connections of the pool.
        :param pool_lock: pool-level lock guarding cross-connection reads.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.db = self._conn.db
        self._event_dispatcher = self._conn._event_dispatcher
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # cache key of the command currently awaiting its response, if any
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._oss_cluster_maint_notifications_handler,
                self._conn._get_parser(),
                event_dispatcher=self._conn.event_dispatcher,
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_cluster_handler_for_connection(
                oss_cluster_maint_notifications_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the inner connection and verify the server supports CSC."""
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args, **kwargs):
        """Flush the cache, then disconnect the inner connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args, **kwargs)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Fix: forward check_health instead of silently dropping it so
        # callers that must skip the health check (e.g. the health check
        # itself) are honored.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, consulting/priming the cache for cachable reads."""
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the cached response when available; otherwise read from
        the socket and (if still valid) store the response in the cache."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                if (
                    self._cache.get(self._current_command_cache_key) is not None
                    and self._cache.get(self._current_command_cache_key).status
                    != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(
                        self._cache.get(self._current_command_cache_key).cache_value
                    )
                    self._current_command_cache_key = None
                    record_csc_request(
                        result=CSCResult.HIT,
                    )
                    record_csc_network_saved(
                        bytes_saved=len(res),
                    )
                    return res
                record_csc_request(
                    result=CSCResult.MISS,
                )

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    @property
    def _maint_notifications_connection_handler(
        self,
    ) -> Optional[MaintNotificationsConnectionHandler]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._maint_notifications_connection_handler

    @_maint_notifications_connection_handler.setter
    def _maint_notifications_connection_handler(
        self, value: Optional[MaintNotificationsConnectionHandler]
    ):
        self._conn._maint_notifications_connection_handler = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance
        notifications we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # Fix: return the formatted host error -- previously the result was
        # computed and discarded, so callers got None in error messages.
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Turn on server-assisted invalidation for this connection."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain any queued invalidation push messages before sending."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Handle a server invalidation push: evict keys (or flush on None)."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                keys_deleted = self._cache.delete_by_redis_keys(data[1])

                if len(keys_deleted) > 0:
                    record_csc_eviction(
                        count=len(keys_deleted),
                        reason=CSCReason.INVALIDATION,
                    )

    def extract_connection_details(self) -> str:
        return self._conn.extract_connection_details()
1922class SSLConnection(Connection):
1923 """Manages SSL connections to and from the Redis server(s).
1924 This class extends the Connection class, adding SSL functionality, and making
1925 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
1926 """ # noqa
    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            # map the user-facing string names onto ssl.VerifyMode values
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # hostname checking is only meaningful when certificates are verified;
        # ssl.SSLContext rejects check_hostname=True with CERT_NONE
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)
2009 def _connect(self):
2010 """
2011 Wrap the socket with SSL support, handling potential errors.
2012 """
2013 sock = super()._connect()
2014 try:
2015 return self._wrap_socket_with_ssl(sock)
2016 except (OSError, RedisError):
2017 sock.close()
2018 raise
    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Builds an ``ssl.SSLContext`` from the settings captured at
        construction time (verify mode, verify flags, client cert chain,
        CA locations, minimum TLS version, ciphers), wraps ``sock`` with it,
        and optionally performs OCSP validation (either stapled or "pure",
        never both).

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if OCSP validation is requested without the
                ``cryptography`` package, or if both stapled and pure OCSP
                validation are enabled at once.
            ConnectionError: if pure OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        # Apply include flags first, then clear excluded ones, so an
        # exclusion wins if the same flag appears in both lists.
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket: the pyOpenSSL handshake below is done on a
            # separate throwaway connection purely to fetch/verify the staple.
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    """Manages communication with a Redis server over a Unix domain socket."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        # Let the base class initialize first; then record the UDS-specific
        # settings (this also overrides any base-class socket_timeout).
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs used to build this connection's repr."""
        if self.client_name:
            return [
                ("path", self.path),
                ("db", self.db),
                ("client_name", self.client_name),
            ]
        return [("path", self.path), ("db", self.db)]

    def _connect(self):
        """Create, connect and return a Unix domain socket."""
        uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds_sock.settimeout(self.socket_connect_timeout)
        try:
            uds_sock.connect(self.path)
        except OSError:
            # Shut down and close the half-open socket so it does not
            # trigger a ResourceWarning later.
            try:
                uds_sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            uds_sock.close()
            raise
        # Connected: swap the connect timeout for the operation timeout.
        uds_sock.settimeout(self.socket_timeout)
        return uds_sock

    def _host_error(self):
        """The UDS path stands in for a host name in error messages."""
        return self.path
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a URL querystring value to a boolean.

    Returns ``None`` for a missing value (``None`` or the empty string),
    ``False`` for case-insensitive false-like strings ("0", "f", "false",
    "n", "no"), and the ordinary truthiness of anything else.
    """
    if value is None or value == "":
        return None
    is_falsy_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_falsy_string else bool(value)
2152def parse_ssl_verify_flags(value):
2153 # flags are passed in as a string representation of a list,
2154 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
2155 verify_flags_str = value.replace("[", "").replace("]", "")
2157 verify_flags = []
2158 for flag in verify_flags_str.split(","):
2159 flag = flag.strip()
2160 if not hasattr(VerifyFlags, flag):
2161 raise ValueError(f"Invalid ssl verify flag: {flag}")
2162 verify_flags.append(getattr(VerifyFlags, flag))
2163 return verify_flags
# Coercion functions applied by parse_url() to known querystring options;
# options not listed here are passed through as plain (unquoted) strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
def parse_url(url):
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into kwargs.

    Querystring options are coerced via URL_QUERY_ARGUMENT_PARSERS; unknown
    options are passed through as (unquoted) strings. Username, password,
    host, port, db and the connection class are derived from the URL parts.

    Args:
        url: The connection URL string.

    Returns:
        A dict of keyword arguments suitable for a ConnectionPool /
        connection class constructor.

    Raises:
        ValueError: if the scheme is unsupported or an option value cannot
            be coerced to its expected type.
    """
    # str.startswith accepts a tuple: one call instead of an `or` chain.
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # parse_qs yields non-empty lists, but stay defensive about
        # empty values; a truthiness check covers both cases.
        if value:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
# TypeVar so from_url() called on a ConnectionPool subclass is typed as
# returning that subclass rather than the base class.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract interface that every connection pool implementation provides."""

    @abstractmethod
    def get_protocol(self):
        # The configured RESP protocol version for the pool's connections.
        pass

    @abstractmethod
    def reset(self):
        # Forget all pooled connections and reinitialize internal state.
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        # Check a connection out of the pool; positional/keyword args are
        # deprecated (see decorator above).
        pass

    @abstractmethod
    def get_encoder(self):
        # Encoder instance configured for this pool's connections.
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        # Return a previously checked-out connection to the pool.
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        # Disconnect pooled sockets; optionally also those currently in use.
        pass

    @abstractmethod
    def close(self):
        # Permanently close the pool.
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        # Replace the retry policy used for new and existing connections.
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        # Re-authenticate connections when a new credential token arrives.
        pass

    @abstractmethod
    def get_connection_count(self) -> list[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        pass
class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        # Initialize maintenance notifications.
        # Notifications default to enabled when RESP3 is in use and no
        # explicit config was supplied; they require RESP3 to work at all.
        is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3)

        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._event_dispatcher = kwargs.get("event_dispatcher", None)
            if self._event_dispatcher is None:
                self._event_dispatcher = EventDispatcher()

            # NOTE(review): this handler is discarded on both branches below
            # (set to None in the OSS-cluster case, re-created in the else
            # case) — the construction here looks redundant; confirm
            # MaintNotificationsPoolHandler() has no registration side effects.
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
            if oss_cluster_maint_notifications_handler:
                # OSS cluster mode: a cluster-level handler is used instead of
                # a per-pool handler.
                self._oss_cluster_maint_notifications_handler = (
                    oss_cluster_maint_notifications_handler
                )
                self._update_connection_kwargs_for_maint_notifications(
                    oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
                )
                self._maint_notifications_pool_handler = None
            else:
                self._oss_cluster_maint_notifications_handler = None
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )

                self._update_connection_kwargs_for_maint_notifications(
                    maint_notifications_pool_handler=self._maint_notifications_pool_handler
                )
        else:
            # Notifications disabled (or no config and pre-RESP3 protocol).
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        # Provided by the concrete pool: kwargs used for new connections.
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        # Lock guarding the pool's connection collections.
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently idle in the pool.
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently checked out by clients.
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        if self._oss_cluster_maint_notifications_handler:
            maint_notifications_config = (
                self._oss_cluster_maint_notifications_handler.config
            )
        else:
            maint_notifications_config = (
                self._maint_notifications_pool_handler.config
                if self._maint_notifications_pool_handler
                else None
            )

        # NOTE: with no handler configured this `and` expression evaluates to
        # None (falsy), not False.
        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Raises:
            ValueError: when attempting to disable notifications that are
                already enabled.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            # first update pool settings
            if not self._maint_notifications_pool_handler:
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )
            else:
                self._maint_notifications_pool_handler.config = (
                    maint_notifications_config
                )

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.

        Injects the relevant handler(s) and config into connection_kwargs and
        records the original host/timeouts once, so they can be restored when
        maintenance-driven temporary settings are reset.
        """
        if not self.maint_notifications_enabled():
            return
        if maint_notifications_pool_handler:
            self.connection_kwargs.update(
                {
                    "maint_notifications_pool_handler": maint_notifications_pool_handler,
                    "maint_notifications_config": maint_notifications_pool_handler.config,
                }
            )
        if oss_cluster_maint_notifications_handler:
            self.connection_kwargs.update(
                {
                    "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
                    "maint_notifications_config": oss_cluster_maint_notifications_handler.config,
                }
            )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """Update the maintenance notifications config for all connections in the pool.

        Free connections are disconnected so they reconnect with the new
        config; in-use connections are only marked for reconnect so callers
        are not interrupted mid-command.

        Raises:
            ValueError: if neither handler is provided.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if oss_cluster_maint_notifications_handler:
                    # set cluster handler for conn
                    conn.set_maint_notifications_cluster_handler_for_connection(
                        oss_cluster_maint_notifications_handler
                    )
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                if oss_cluster_maint_notifications_handler:
                    # NOTE(review): this branch calls
                    # _configure_maintenance_notifications() while the free
                    # branch above uses the cluster-handler setter — confirm
                    # the asymmetry is intentional.
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                    conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy matching value (None/empty) makes the corresponding pattern
        match every connection.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Only the fields explicitly requested (non-None values / True flags)
        are touched; everything else is left as-is.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        # Apply the (possibly relaxed) timeout to the live socket as well.
        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()
2722class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
2723 """
    Create a connection pool. If ``max_connections`` is set, then this
2725 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
2726 limit is reached.
2728 By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
2730 unix sockets.
2731 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
2733 If ``maint_notifications_config`` is provided, the connection pool will support
2734 maintenance notifications.
2735 Maintenance notifications are supported only with RESP3.
2736 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2737 the maintenance notifications will be enabled by default.
2739 Any additional keyword arguments are passed to the constructor of
2740 ``connection_class``.
2741 """
2743 @classmethod
2744 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
2745 """
2746 Return a connection pool configured from the given URL.
2748 For example::
2750 redis://[[username]:[password]]@localhost:6379/0
2751 rediss://[[username]:[password]]@localhost:6379/0
2752 unix://[username@]/path/to/socket.sock?db=0[&password=password]
2754 Three URL schemes are supported:
2756 - `redis://` creates a TCP socket connection. See more at:
2757 <https://www.iana.org/assignments/uri-schemes/prov/redis>
2758 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
2759 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
2760 - ``unix://``: creates a Unix Domain Socket connection.
2762 The username, password, hostname, path and all querystring values
2763 are passed through urllib.parse.unquote in order to replace any
2764 percent-encoded values with their corresponding characters.
2766 There are several ways to specify a database number. The first value
2767 found will be used:
2769 1. A ``db`` querystring option, e.g. redis://localhost?db=0
2770 2. If using the redis:// or rediss:// schemes, the path argument
2771 of the url, e.g. redis://localhost/0
2772 3. A ``db`` keyword argument to this function.
2774 If none of these options are specified, the default db=0 is used.
2776 All querystring options are cast to their appropriate Python types.
2777 Boolean arguments can be specified with string values "True"/"False"
2778 or "Yes"/"No". Values that cannot be properly cast cause a
2779 ``ValueError`` to be raised. Once parsed, the querystring arguments
2780 and keyword arguments are passed to the ``ConnectionPool``'s
2781 class initializer. In the case of conflicting arguments, querystring
2782 arguments always win.
2783 """
2784 url_options = parse_url(url)
2786 if "connection_class" in kwargs:
2787 url_options["connection_class"] = kwargs["connection_class"]
2789 kwargs.update(url_options)
2790 return cls(**kwargs)
    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """Create the pool; extra kwargs are forwarded to ``connection_class``.

        Raises:
            ValueError: for a non-integer or negative ``max_connections``, or
                a ``cache`` kwarg that is not a CacheInterface.
            RedisError: when client-side caching is requested without RESP3.
        """
        # Effectively "unlimited" when no explicit cap is given.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # Client-side caching setup (requires RESP3).
        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if not check_protocol_version(self._connection_kwargs.get("protocol"), 3):
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                # NOTE(review): the factory path wraps the cache in CacheProxy
                # while the cache_config path does not — confirm this
                # asymmetry is intentional.
                if self._cache_factory is not None:
                    self.cache = CacheProxy(self._cache_factory.get_cache())
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

            # Register the cache-size observability callback for this pool.
            init_csc_items()
            register_csc_items_callback(
                callback=lambda: self.cache.size,
                pool_name=get_pool_name(self),
            )

        # Cache settings are consumed here; do not forward them to connections.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        # Generate unique pool ID for observability (matches go-redis behavior)
        import secrets

        self._pool_id = secrets.token_hex(4)

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()
2867 def __repr__(self) -> str:
2868 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
2869 return (
2870 f"<{self.__class__.__module__}.{self.__class__.__name__}"
2871 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
2872 f"({conn_kwargs})>)>"
2873 )
    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct new connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        # Replaces the kwargs used for all future connections; existing
        # connections are unaffected.
        self._connection_kwargs = value
2883 def get_protocol(self):
2884 """
2885 Returns:
2886 The RESP protocol version, or ``None`` if the protocol is not specified,
2887 in which case the server default will be used.
2888 """
2889 return self.connection_kwargs.get("protocol", None)
    def reset(self) -> None:
        """Forget all pooled connections and reinitialize counters.

        Called at construction time and by _checkpid() after a fork.
        """
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
    def _checkpid(self) -> None:
        """Detect a fork and reset the pool state in the child process.

        Raises:
            ChildDeadlockedError: if _fork_lock cannot be acquired within
                5 seconds (see the deadlock scenario described below).
        """
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """Get a connection from the pool, creating a new one when none is idle.

        Raises:
            MaxConnectionsError: when ``make_connection`` hits the pool limit.
            ConnectionError: if a healthy connection cannot be established.
        """
        self._checkpid()
        is_created = False

        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                # Start timing for observability
                start_time_created = time.monotonic()

                connection = self.make_connection()
                is_created = True
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                # Pending readable data is only treated as an error when
                # neither client-side caching nor maintenance notifications
                # are active: both of those can legitimately leave push
                # messages on the socket.
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )
        return connection
3009 def get_encoder(self) -> Encoder:
3010 "Return an encoder based on encoding settings"
3011 kwargs = self.connection_kwargs
3012 return Encoder(
3013 encoding=kwargs.get("encoding", "utf-8"),
3014 encoding_errors=kwargs.get("encoding_errors", "strict"),
3015 decode_responses=kwargs.get("decode_responses", False),
3016 )
3018 def make_connection(self) -> "ConnectionInterface":
3019 "Create a new connection"
3020 if self._created_connections >= self.max_connections:
3021 raise MaxConnectionsError("Too many connections")
3022 self._created_connections += 1
3024 kwargs = dict(self.connection_kwargs)
3026 if self.cache is not None:
3027 return CacheProxyConnection(
3028 self.connection_class(**kwargs), self.cache, self._lock
3029 )
3030 return self.connection_class(**kwargs)
3032 def release(self, connection: "Connection") -> None:
3033 "Releases the connection back to the pool"
3034 self._checkpid()
3035 with self._lock:
3036 try:
3037 self._in_use_connections.remove(connection)
3038 except KeyError:
3039 # Gracefully fail when a connection is returned to this pool
3040 # that the pool doesn't actually own
3041 return
3043 if self.owns_connection(connection):
3044 if connection.should_reconnect():
3045 connection.disconnect()
3046 self._available_connections.append(connection)
3047 self._event_dispatcher.dispatch(
3048 AfterConnectionReleasedEvent(connection)
3049 )
3050 else:
3051 # Pool doesn't own this connection, do not add it back
3052 # to the pool.
3053 # The created connections count should not be changed,
3054 # because the connection was not created by the pool.
3055 connection.disconnect()
3056 return
3058 def owns_connection(self, connection: "Connection") -> int:
3059 return connection.pid == self.pid
3061 def disconnect(self, inuse_connections: bool = True) -> None:
3062 """
3063 Disconnects connections in the pool
3065 If ``inuse_connections`` is True, disconnect connections that are
3066 currently in use, potentially by other threads. Otherwise only disconnect
3067 connections that are idle in the pool.
3068 """
3069 self._checkpid()
3070 with self._lock:
3071 if inuse_connections:
3072 connections = chain(
3073 self._available_connections, self._in_use_connections
3074 )
3075 else:
3076 connections = self._available_connections
3078 for connection in connections:
3079 connection.disconnect()
    def close(self) -> None:
        """Close the pool, disconnecting all connections (idle and in-use)."""
        self.disconnect()
3085 def set_retry(self, retry: Retry) -> None:
3086 self.connection_kwargs.update({"retry": retry})
3087 for conn in self._available_connections:
3088 conn.retry = retry
3089 for conn in self._in_use_connections:
3090 conn.retry = retry
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly issued token.

        Idle connections are re-AUTHed synchronously here; in-use connections
        are handed the token via ``set_re_auth_token`` (presumably applied by
        the connection itself later — confirm against Connection).
        """
        with self._lock:
            for conn in self._available_connections:
                # The lambdas capture the loop variable `conn`, which is safe
                # here only because call_with_retry invokes them synchronously
                # within this same iteration.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)
    def _get_pool_lock(self):
        """Return the lock guarding this pool's internal state."""
        return self._lock
3110 def _get_free_connections(self):
3111 with self._lock:
3112 return list(self._available_connections)
3114 def _get_in_use_connections(self):
3115 with self._lock:
3116 return set(self._in_use_connections)
    def _mock(self, error: RedisError):
        """
        Dummy function; passed as the error callback to the retry object so
        that errors raised during retries are deliberately ignored.
        :param error: the exception raised by the retried call (unused).
        :return: None
        """
        pass
3126 def get_connection_count(self) -> List[tuple[int, dict]]:
3127 from redis.observability.attributes import get_pool_name
3129 attributes = AttributeBuilder.build_base_attributes()
3130 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
3131 free_connections_attributes = attributes.copy()
3132 in_use_connections_attributes = attributes.copy()
3134 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3135 ConnectionState.IDLE.value
3136 )
3137 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
3138 ConnectionState.USED.value
3139 )
3141 return [
3142 (len(self._get_free_connections()), free_connections_attributes),
3143 (len(self._get_in_use_connections()), in_use_connections_attributes),
3144 ]
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather
    than raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """
    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        """
        :param max_connections: maximum number of connections the pool creates.
        :param timeout: seconds to wait for a free connection; ``None`` blocks
            forever.
        :param connection_class: class used to create new connections.
        :param queue_class: queue type holding the idle-connection slots
            (LIFO by default, so recently used connections are reused first).
        :param connection_kwargs: forwarded to ``connection_class``.
        """
        # These attributes must be set before super().__init__(), which is
        # expected to call reset() — reset() reads queue_class and the
        # maintenance flags. (NOTE(review): assumes ConnectionPool.__init__
        # invokes reset(); confirm.)
        self.queue_class = queue_class
        self.timeout = timeout
        self._in_maintenance = False
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
    def reset(self):
        """Recreate the queue of connection slots for the current process."""
        # Create and fill up a thread safe queue with ``None`` values.
        try:
            if self._in_maintenance:
                # Maintenance mode (MOVING notification handling) serializes
                # pool mutation through self._lock. NOTE(review): this method
                # can run while a caller already holds the lock, so it appears
                # to rely on the lock being reentrant — confirm it is an RLock.
                self._lock.acquire()
                self._locked = True
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
    def make_connection(self):
        """Make a fresh connection and record it in ``self._connections``."""
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True

            if self.cache is not None:
                # Client-side caching enabled: wrap the raw connection in a
                # cache-aware proxy sharing the pool's lock.
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            self._connections.append(connection)
            return connection
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Start timing for observability
        start_time_acquired = time.monotonic()
        # Make sure we haven't changed process.
        self._checkpid()
        is_created = False

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. If you want to
                # block indefinitely instead of raising, construct the pool
                # with ``timeout=None``.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                # Start timing for observability
                start_time_created = time.monotonic()
                connection = self.make_connection()
                is_created = True
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        record_connection_wait_time(
            pool_name=get_pool_name(self),
            duration_seconds=time.monotonic() - start_time_acquired,
        )

        return connection
    def release(self, connection):
        """Release the connection back to the pool, or discard it (putting a
        ``None`` placeholder in its slot) when this pool/process does not
        own it."""
        # Make sure we haven't changed process.
        self._checkpid()

        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect either every created connection (the default) or only
        the connections currently idle in the queue."""
        self._checkpid()
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if inuse_connections:
                # self._connections tracks every connection ever created, so
                # this covers both idle and checked-out connections.
                connections = self._connections
            else:
                # NOTE(review): _get_free_connections() acquires self._lock
                # again; in maintenance mode the lock is already held here,
                # so this relies on the lock being reentrant — confirm it is
                # an RLock.
                connections = self._get_free_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False
3398 def _get_free_connections(self):
3399 with self._lock:
3400 return {conn for conn in self.pool.queue if conn}
3402 def _get_in_use_connections(self):
3403 with self._lock:
3404 # free connections
3405 connections_in_queue = {conn for conn in self.pool.queue if conn}
3406 # in self._connections we keep all created connections
3407 # so the ones that are not in the queue are the in use ones
3408 return {
3409 conn for conn in self._connections if conn not in connections_in_queue
3410 }
    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this Blocking ConnectionPool is in maintenance mode.

        This is used to prevent new connections from being created while we
        are in maintenance mode. The pool will be in maintenance mode only
        while a MOVING notification is being processed.
        """
        self._in_maintenance = in_maintenance