Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .driver_info import DriverInfo, resolve_driver_info
39from .event import AfterConnectionReleasedEvent, EventDispatcher
40from .exceptions import (
41 AuthenticationError,
42 AuthenticationWrongNumberOfArgsError,
43 ChildDeadlockedError,
44 ConnectionError,
45 DataError,
46 MaxConnectionsError,
47 RedisError,
48 ResponseError,
49 TimeoutError,
50)
51from .maint_notifications import (
52 MaintenanceState,
53 MaintNotificationsConfig,
54 MaintNotificationsConnectionHandler,
55 MaintNotificationsPoolHandler,
56 OSSMaintNotificationsHandler,
57)
58from .retry import Retry
59from .utils import (
60 CRYPTOGRAPHY_AVAILABLE,
61 HIREDIS_AVAILABLE,
62 SSL_AVAILABLE,
63 check_protocol_version,
64 compare_versions,
65 deprecated_args,
66 ensure_string,
67 format_error_message,
68 str_if_bytes,
69)
# Optional features are imported only when available; otherwise the names are
# bound to None so later code can test them without an ImportError.
if not SSL_AVAILABLE:
    ssl = None
    VerifyFlags = None
else:
    import ssl
    from ssl import VerifyFlags

if HIREDIS_AVAILABLE:
    import hiredis

# Byte tokens used when framing commands in the RESP wire format.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# Fallback protocol version used when ``int(protocol)`` raises TypeError
# (for example when ``protocol`` is None).
DEFAULT_RESP_VERSION = 2

# Unique placeholder meaning "argument not supplied"; distinct from None.
SENTINEL = object()

# Prefer the hiredis-backed parser when the extension is installed.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
class HiredisRespSerializer:
    """Serialize commands to RESP bytes using the ``hiredis`` C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        command, rest = args[0], args[1:]
        # Multi-word command names such as "CONFIG GET" must reach the server
        # as separate arguments; the split pieces are bytes, so they are not
        # re-encoded.
        if isinstance(command, str):
            args = tuple(command.encode().split()) + rest
        elif b" " in command:
            args = tuple(command.split()) + rest
        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Report unsupported argument types as DataError while keeping the
            # original traceback for debugging.
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Pure-Python RESP serializer.

    ``buffer_cutoff`` bounds how large the coalesced output buffer may grow
    before payloads are emitted as separate chunks. ``encode`` converts each
    argument into a bytes-like value (bytes or memoryview).
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # The command name may embed literal sub-arguments (e.g. 'CONFIG GET').
        # The server expects them as separate arguments, so split the first
        # one. The pieces are bytes and therefore pass through the encoder
        # unchanged.
        command, rest = args[0], args[1:]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + rest
        elif b" " in command:
            args = tuple(command.split()) + rest

        limit = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for payload in map(self.encode, args):
            size = len(payload)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            # Flush the buffer and emit the payload as its own chunk when the
            # buffer or payload is large, or when the payload is a memoryview.
            # This avoids copying big values into one bytes object.
            must_flush = (
                len(pending) > limit
                or size > limit
                or isinstance(payload, memoryview)
            )
            if must_flush:
                chunks.append(SYM_EMPTY.join((pending, header)))
                chunks.append(payload)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, payload, SYM_CRLF))
        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """Interface that synchronous connection classes implement.

    The ``@abstractmethod`` markers document the required surface. They are
    not enforced at instantiation because this class does not derive from
    ``abc.ABC``.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs used to build ``__repr__``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run when the connection is (re)established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously passed to ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Instantiate ``parser_class`` and attach it to this connection."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by this connection."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""

    @abstractmethod
    def check_health(self):
        """Check the connection's health (semantics defined by implementations)."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send a command that has already been packed into RESP bytes."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Report whether a reply can be read (within ``timeout``)."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the next reply from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Pack a single command into RESP output."""

    @abstractmethod
    def pack_commands(self, commands):
        """Pack several commands (e.g. for a pipeline) into RESP output."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used for re-authentication."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection (implementation-defined)."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the push-notification
                handlers are attached to. Required when maintenance notifications are
                enabled; a RedisError is raised if it is missing or is not a hiredis/RESP3 parser.

        Raises:
            RedisError: notifications are enabled but ``parser`` is missing or unsupported.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # Ids of notifications already seen, exposed via the
        # add_*/get_* helpers below.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        When notifications are disabled, all handlers are set to None and the
        original connection parameters are not recorded.

        Raises:
            RedisError: notifications are enabled but ``parser`` is missing or
                is neither a hiredis nor a RESP3 parser.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, (_HiredisParser, _RESP3Parser)):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference.
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters (explicit values win; falsy
        # values fall back to the connection's current settings).
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def _store_orig_connection_settings_if_missing(self):
        """Record the current host/timeouts as originals if none are stored yet.

        ``_configure_maintenance_notifications`` only records these when
        notifications are enabled at construction time. Handlers attached
        later may restore settings through ``reset_tmp_settings``, which reads
        these attributes, so make sure they exist. Values that are already
        recorded are never overwritten.
        """
        if not hasattr(self, "orig_host_address"):
            self.orig_host_address = self.host
        if not hasattr(self, "orig_socket_timeout"):
            self.orig_socket_timeout = self.socket_timeout
        if not hasattr(self, "orig_socket_connect_timeout"):
            self.orig_socket_connect_timeout = self.socket_connect_timeout

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of ``maint_notifications_pool_handler``.

        Also creates (or re-configures) the connection-level handler so it
        uses the pool handler's config.
        """
        # Copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # The connection may have been created with notifications disabled,
        # in which case no original settings were recorded yet.
        self._store_orig_connection_settings_if_missing()

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """Attach the OSS-cluster notification handler to this connection's parser.

        Also creates (or re-configures) the connection-level handler so it
        uses the cluster handler's config.
        """
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # The connection may have been created with notifications disabled,
        # in which case no original settings were recorded yet.
        self._store_orig_connection_settings_if_missing()

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Send the maintenance-notifications handshake when it applies.

        The handshake is sent only if the protocol is not RESP2, notifications
        are enabled, a connection handler exists and the connection has a host.
        In "auto" mode a server rejection is only logged; otherwise
        it propagates.
        """
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send ``CLIENT MAINT_NOTIFICATIONS ON`` and verify the server replied OK.

        Raises:
            ValueError: the connection has no host to derive the endpoint type from.
            ResponseError: the server did not reply ``OK``. This is suppressed
                (and logged at debug level) when
                ``maint_notifications_config.enabled == "auto"``.
            Any other exception raised while sending/reading propagates.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except ResponseError as e:
            if maint_notifications_config.enabled != "auto":
                raise
            # In "auto" mode a rejected handshake is not fatal: record it at
            # debug level and keep the connection usable.
            import logging

            logging.getLogger(__name__).debug(
                "Failed to enable maintenance notifications: %s", e
            )

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution of ``self.host``/``self.port``.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # IP sockets report a (host, port[, ...]) tuple. Other families
                # (e.g. AF_UNIX) report a path string, whose first character is
                # not an address, so fall through to DNS resolution instead.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection.

        For IP sockets this is the host part of the peer address tuple; for
        other families (e.g. AF_UNIX, which reports a path) the raw address is
        returned instead of its first character. Returns None without a socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            peer_addr = conn_socket.getpeername()
            if isinstance(peer_addr, tuple):
                return peer_addr[0]
            return peer_addr
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply ``relaxed_timeout`` (or, for -1, ``self.socket_timeout``) to the live socket and parser."""
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # if the current timeout is 0 it means we are in the middle of a can_read call
            # in this case we don't want to change the timeout because the operation
            # is non-blocking and should return immediately
            # Changing the state from non-blocking to blocking in the middle of a read operation
            # will lead to a deadlock
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's own timeout bookkeeping."""
        parser = self._get_parser()
        if parser and parser._buffer:
            # NOTE(review): a falsy timeout (None or 0) is not propagated to the
            # RESP3 buffer, so restoring an original ``None`` timeout leaves the
            # buffer at its previous value - confirm this is intended.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override the host and/or timeouts.

        The value of SENTINEL is used to indicate that the host should not be
        updated; a ``tmp_relaxed_timeout`` of -1 leaves both timeouts unchanged.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the recorded original host and/or timeouts."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
729class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
730 "Manages communication to and from a Redis server"
732 @deprecated_args(
733 args_to_warn=["lib_name", "lib_version"],
734 reason="Use 'driver_info' parameter instead. "
735 "lib_name and lib_version will be removed in a future version.",
736 )
737 def __init__(
738 self,
739 db: int = 0,
740 password: Optional[str] = None,
741 socket_timeout: Optional[float] = None,
742 socket_connect_timeout: Optional[float] = None,
743 retry_on_timeout: bool = False,
744 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
745 encoding: str = "utf-8",
746 encoding_errors: str = "strict",
747 decode_responses: bool = False,
748 parser_class=DefaultParser,
749 socket_read_size: int = 65536,
750 health_check_interval: int = 0,
751 client_name: Optional[str] = None,
752 lib_name: Optional[str] = None,
753 lib_version: Optional[str] = None,
754 driver_info: Optional[DriverInfo] = None,
755 username: Optional[str] = None,
756 retry: Union[Any, None] = None,
757 redis_connect_func: Optional[Callable[[], None]] = None,
758 credential_provider: Optional[CredentialProvider] = None,
759 protocol: Optional[int] = 2,
760 command_packer: Optional[Callable[[], None]] = None,
761 event_dispatcher: Optional[EventDispatcher] = None,
762 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
763 maint_notifications_pool_handler: Optional[
764 MaintNotificationsPoolHandler
765 ] = None,
766 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
767 maintenance_notification_hash: Optional[int] = None,
768 orig_host_address: Optional[str] = None,
769 orig_socket_timeout: Optional[float] = None,
770 orig_socket_connect_timeout: Optional[float] = None,
771 oss_cluster_maint_notifications_handler: Optional[
772 OSSMaintNotificationsHandler
773 ] = None,
774 ):
775 """
776 Initialize a new Connection.
778 To specify a retry policy for specific errors, first set
779 `retry_on_error` to a list of the error/s to retry on, then set
780 `retry` to a valid `Retry` object.
781 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
783 Parameters
784 ----------
785 driver_info : DriverInfo, optional
786 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
787 are ignored. If not provided, a DriverInfo will be created from lib_name
788 and lib_version (or defaults if those are also None).
789 lib_name : str, optional
790 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
791 lib_version : str, optional
792 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
793 """
794 if (username or password) and credential_provider is not None:
795 raise DataError(
796 "'username' and 'password' cannot be passed along with 'credential_"
797 "provider'. Please provide only one of the following arguments: \n"
798 "1. 'password' and (optional) 'username'\n"
799 "2. 'credential_provider'"
800 )
801 if event_dispatcher is None:
802 self._event_dispatcher = EventDispatcher()
803 else:
804 self._event_dispatcher = event_dispatcher
805 self.pid = os.getpid()
806 self.db = db
807 self.client_name = client_name
809 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
810 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
812 self.credential_provider = credential_provider
813 self.password = password
814 self.username = username
815 self._socket_timeout = socket_timeout
816 if socket_connect_timeout is None:
817 socket_connect_timeout = socket_timeout
818 self._socket_connect_timeout = socket_connect_timeout
819 self.retry_on_timeout = retry_on_timeout
820 if retry_on_error is SENTINEL:
821 retry_on_errors_list = []
822 else:
823 retry_on_errors_list = list(retry_on_error)
824 if retry_on_timeout:
825 # Add TimeoutError to the errors list to retry on
826 retry_on_errors_list.append(TimeoutError)
827 self.retry_on_error = retry_on_errors_list
828 if retry or self.retry_on_error:
829 if retry is None:
830 self.retry = Retry(NoBackoff(), 1)
831 else:
832 # deep-copy the Retry object as it is mutable
833 self.retry = copy.deepcopy(retry)
834 if self.retry_on_error:
835 # Update the retry's supported errors with the specified errors
836 self.retry.update_supported_errors(self.retry_on_error)
837 else:
838 self.retry = Retry(NoBackoff(), 0)
839 self.health_check_interval = health_check_interval
840 self.next_health_check = 0
841 self.redis_connect_func = redis_connect_func
842 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
843 self.handshake_metadata = None
844 self._sock = None
845 self._socket_read_size = socket_read_size
846 self._connect_callbacks = []
847 self._buffer_cutoff = 6000
848 self._re_auth_token: Optional[TokenInterface] = None
849 try:
850 p = int(protocol)
851 except TypeError:
852 p = DEFAULT_RESP_VERSION
853 except ValueError:
854 raise ConnectionError("protocol must be an integer")
855 finally:
856 if p < 2 or p > 3:
857 raise ConnectionError("protocol must be either 2 or 3")
858 # p = DEFAULT_RESP_VERSION
859 self.protocol = p
860 if self.protocol == 3 and parser_class == _RESP2Parser:
861 # If the protocol is 3 but the parser is RESP2, change it to RESP3
862 # This is needed because the parser might be set before the protocol
863 # or might be provided as a kwarg to the constructor
864 # We need to react on discrepancy only for RESP2 and RESP3
865 # as hiredis supports both
866 parser_class = _RESP3Parser
867 self.set_parser(parser_class)
869 self._command_packer = self._construct_command_packer(command_packer)
870 self._should_reconnect = False
872 # Set up maintenance notifications
873 MaintNotificationsAbstractConnection.__init__(
874 self,
875 maint_notifications_config,
876 maint_notifications_pool_handler,
877 maintenance_state,
878 maintenance_notification_hash,
879 orig_host_address,
880 orig_socket_timeout,
881 orig_socket_connect_timeout,
882 oss_cluster_maint_notifications_handler,
883 self._parser,
884 )
886 def __repr__(self):
887 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
888 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
890 @abstractmethod
891 def repr_pieces(self):
892 pass
894 def __del__(self):
895 try:
896 self.disconnect()
897 except Exception:
898 pass
900 def _construct_command_packer(self, packer):
901 if packer is not None:
902 return packer
903 elif HIREDIS_AVAILABLE:
904 return HiredisRespSerializer()
905 else:
906 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
908 def register_connect_callback(self, callback):
909 """
910 Register a callback to be called when the connection is established either
911 initially or reconnected. This allows listeners to issue commands that
912 are ephemeral to the connection, for example pub/sub subscription or
913 key tracking. The callback must be a _method_ and will be kept as
914 a weak reference.
915 """
916 wm = weakref.WeakMethod(callback)
917 if wm not in self._connect_callbacks:
918 self._connect_callbacks.append(wm)
920 def deregister_connect_callback(self, callback):
921 """
922 De-register a previously registered callback. It will no-longer receive
923 notifications on connection events. Calling this is not required when the
924 listener goes away, since the callbacks are kept as weak methods.
925 """
926 try:
927 self._connect_callbacks.remove(weakref.WeakMethod(callback))
928 except ValueError:
929 pass
931 def set_parser(self, parser_class):
932 """
933 Creates a new instance of parser_class with socket size:
934 _socket_read_size and assigns it to the parser for the connection
935 :param parser_class: The required parser class
936 """
937 self._parser = parser_class(socket_read_size=self._socket_read_size)
939 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
940 return self._parser
942 def connect(self):
943 "Connects to the Redis server if not already connected"
944 # try once the socket connect with the handshake, retry the whole
945 # connect/handshake flow based on retry policy
946 self.retry.call_with_retry(
947 lambda: self.connect_check_health(
948 check_health=True, retry_socket_connect=False
949 ),
950 lambda error: self.disconnect(error),
951 )
953 def connect_check_health(
954 self, check_health: bool = True, retry_socket_connect: bool = True
955 ):
956 if self._sock:
957 return
958 try:
959 if retry_socket_connect:
960 sock = self.retry.call_with_retry(
961 lambda: self._connect(), lambda error: self.disconnect(error)
962 )
963 else:
964 sock = self._connect()
965 except socket.timeout:
966 raise TimeoutError("Timeout connecting to server")
967 except OSError as e:
968 raise ConnectionError(self._error_message(e))
970 self._sock = sock
971 try:
972 if self.redis_connect_func is None:
973 # Use the default on_connect function
974 self.on_connect_check_health(check_health=check_health)
975 else:
976 # Use the passed function redis_connect_func
977 self.redis_connect_func(self)
978 except RedisError:
979 # clean up after any error in on_connect
980 self.disconnect()
981 raise
983 # run any user callbacks. right now the only internal callback
984 # is for pubsub channel/pattern resubscription
985 # first, remove any dead weakrefs
986 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
987 for ref in self._connect_callbacks:
988 callback = ref()
989 if callback:
990 callback(self)
992 @abstractmethod
993 def _connect(self):
994 pass
996 @abstractmethod
997 def _host_error(self):
998 pass
1000 def _error_message(self, exception):
1001 return format_error_message(self._host_error(), exception)
1003 def on_connect(self):
1004 self.on_connect_check_health(check_health=True)
    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection, authenticate and select a database.

        Runs the post-connect handshake on a freshly opened socket, in order:
        authentication and/or RESP protocol negotiation (HELLO / AUTH),
        maintenance-notification activation, CLIENT SETNAME (if
        ``client_name``), CLIENT SETINFO LIB-NAME / LIB-VER (best effort:
        ``ResponseError`` is ignored), and SELECT (if ``db``).

        Args:
            check_health: forwarded as ``check_health`` to the commands sent
                after authentication. The AUTH and HELLO+AUTH commands always
                use ``check_health=False`` because a PING before AUTH fails.

        Raises:
            AuthenticationError: plain AUTH did not reply "OK".
            ConnectionError: the HELLO reply reports a different RESP version,
                CLIENT SETNAME did not reply "OK", or SELECT did not reply "OK".
        """
        self._parser.on_connect(self)
        # Keep a handle on the original parser so its EXCEPTION_CLASSES can be
        # copied onto a replacement RESP3 parser below.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            # A single credential is a bare password: HELLO ... AUTH needs an
            # explicit username, so use "default".
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # NOTE(review): unlike the HELLO-only branch below, the "proto"
            # field of this reply is not validated against self.protocol.
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The reply map may be keyed by bytes or str depending on decoding,
            # so both spellings of "proto" are checked.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info
        # (best effort: a ResponseError, e.g. from a server that does not
        # support CLIENT SETINFO, is ignored)
        try:
            if self.driver_info and self.driver_info.formatted_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.driver_info.formatted_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.driver_info and self.driver_info.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.driver_info.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
1122 def disconnect(self, *args):
1123 "Disconnects from the Redis server"
1124 self._parser.on_disconnect()
1126 conn_sock = self._sock
1127 self._sock = None
1128 # reset the reconnect flag
1129 self.reset_should_reconnect()
1131 if conn_sock is None:
1132 return
1134 if os.getpid() == self.pid:
1135 try:
1136 conn_sock.shutdown(socket.SHUT_RDWR)
1137 except (OSError, TypeError):
1138 pass
1140 try:
1141 conn_sock.close()
1142 except OSError:
1143 pass
1145 if self.maintenance_state == MaintenanceState.MAINTENANCE:
1146 # this block will be executed only if the connection was in maintenance state
1147 # and the connection was closed.
1148 # The state change won't be applied on connections that are in Moving state
1149 # because their state and configurations will be handled when the moving ttl expires.
1150 self.reset_tmp_settings(reset_relaxed_timeout=True)
1151 self.maintenance_state = MaintenanceState.NONE
1152 # reset the sets that keep track of received start maint
1153 # notifications and skipped end maint notifications
1154 self.reset_received_notifications()
1156 def mark_for_reconnect(self):
1157 self._should_reconnect = True
1159 def should_reconnect(self):
1160 return self._should_reconnect
1162 def reset_should_reconnect(self):
1163 self._should_reconnect = False
1165 def _send_ping(self):
1166 """Send PING, expect PONG in return"""
1167 self.send_command("PING", check_health=False)
1168 if str_if_bytes(self.read_response()) != "PONG":
1169 raise ConnectionError("Bad response from PING health check")
1171 def _ping_failed(self, error):
1172 """Function to call when PING fails"""
1173 self.disconnect()
1175 def check_health(self):
1176 """Check the health of the connection with a PING/PONG"""
1177 if self.health_check_interval and time.monotonic() > self.next_health_check:
1178 self.retry.call_with_retry(self._send_ping, self._ping_failed)
1180 def send_packed_command(self, command, check_health=True):
1181 """Send an already packed command to the Redis server"""
1182 if not self._sock:
1183 self.connect_check_health(check_health=False)
1184 # guard against health check recursion
1185 if check_health:
1186 self.check_health()
1187 try:
1188 if isinstance(command, str):
1189 command = [command]
1190 for item in command:
1191 self._sock.sendall(item)
1192 except socket.timeout:
1193 self.disconnect()
1194 raise TimeoutError("Timeout writing to socket")
1195 except OSError as e:
1196 self.disconnect()
1197 if len(e.args) == 1:
1198 errno, errmsg = "UNKNOWN", e.args[0]
1199 else:
1200 errno = e.args[0]
1201 errmsg = e.args[1]
1202 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1203 except BaseException:
1204 # BaseExceptions can be raised when a socket send operation is not
1205 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1206 # to send un-sent data. However, the send_packed_command() API
1207 # does not support it so there is no point in keeping the connection open.
1208 self.disconnect()
1209 raise
1211 def send_command(self, *args, **kwargs):
1212 """Pack and send a command to the Redis server"""
1213 self.send_packed_command(
1214 self._command_packer.pack(*args),
1215 check_health=kwargs.get("check_health", True),
1216 )
1218 def can_read(self, timeout=0):
1219 """Poll the socket to see if there's data that can be read."""
1220 sock = self._sock
1221 if not sock:
1222 self.connect()
1224 host_error = self._host_error()
1226 try:
1227 return self._parser.can_read(timeout)
1229 except OSError as e:
1230 self.disconnect()
1231 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
1233 def read_response(
1234 self,
1235 disable_decoding=False,
1236 *,
1237 disconnect_on_error=True,
1238 push_request=False,
1239 ):
1240 """Read the response from a previously sent command"""
1242 host_error = self._host_error()
1244 try:
1245 if self.protocol in ["3", 3]:
1246 response = self._parser.read_response(
1247 disable_decoding=disable_decoding, push_request=push_request
1248 )
1249 else:
1250 response = self._parser.read_response(disable_decoding=disable_decoding)
1251 except socket.timeout:
1252 if disconnect_on_error:
1253 self.disconnect()
1254 raise TimeoutError(f"Timeout reading from {host_error}")
1255 except OSError as e:
1256 if disconnect_on_error:
1257 self.disconnect()
1258 raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
1259 except BaseException:
1260 # Also by default close in case of BaseException. A lot of code
1261 # relies on this behaviour when doing Command/Response pairs.
1262 # See #1128.
1263 if disconnect_on_error:
1264 self.disconnect()
1265 raise
1267 if self.health_check_interval:
1268 self.next_health_check = time.monotonic() + self.health_check_interval
1270 if isinstance(response, ResponseError):
1271 try:
1272 raise response
1273 finally:
1274 del response # avoid creating ref cycles
1275 return response
1277 def pack_command(self, *args):
1278 """Pack a series of arguments into the Redis protocol"""
1279 return self._command_packer.pack(*args)
1281 def pack_commands(self, commands):
1282 """Pack multiple commands into the Redis protocol"""
1283 output = []
1284 pieces = []
1285 buffer_length = 0
1286 buffer_cutoff = self._buffer_cutoff
1288 for cmd in commands:
1289 for chunk in self._command_packer.pack(*cmd):
1290 chunklen = len(chunk)
1291 if (
1292 buffer_length > buffer_cutoff
1293 or chunklen > buffer_cutoff
1294 or isinstance(chunk, memoryview)
1295 ):
1296 if pieces:
1297 output.append(SYM_EMPTY.join(pieces))
1298 buffer_length = 0
1299 pieces = []
1301 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1302 output.append(chunk)
1303 else:
1304 pieces.append(chunk)
1305 buffer_length += chunklen
1307 if pieces:
1308 output.append(SYM_EMPTY.join(pieces))
1309 return output
1311 def get_protocol(self) -> Union[int, str]:
1312 return self.protocol
1314 @property
1315 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1316 return self._handshake_metadata
1318 @handshake_metadata.setter
1319 def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
1320 self._handshake_metadata = value
1322 def set_re_auth_token(self, token: TokenInterface):
1323 self._re_auth_token = token
1325 def re_auth(self):
1326 if self._re_auth_token is not None:
1327 self.send_command(
1328 "AUTH",
1329 self._re_auth_token.try_get("oid"),
1330 self._re_auth_token.get_value(),
1331 )
1332 self.read_response()
1333 self._re_auth_token = None
1335 def _get_socket(self) -> Optional[socket.socket]:
1336 return self._sock
1338 @property
1339 def socket_timeout(self) -> Optional[Union[float, int]]:
1340 return self._socket_timeout
1342 @socket_timeout.setter
1343 def socket_timeout(self, value: Optional[Union[float, int]]):
1344 self._socket_timeout = value
1346 @property
1347 def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1348 return self._socket_connect_timeout
1350 @socket_connect_timeout.setter
1351 def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1352 self._socket_connect_timeout = value
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        Args:
            host: server host name or address, resolved with getaddrinfo.
            port: server TCP port (coerced with ``int``).
            socket_keepalive: enable ``SO_KEEPALIVE`` on the socket.
            socket_keepalive_options: mapping of ``IPPROTO_TCP`` option to
                value, applied only when ``socket_keepalive`` is true.
            socket_type: address family passed to ``getaddrinfo``
                (0 means any family).
            **kwargs: forwarded to ``AbstractConnection``.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Mimics ``socket.create_connection`` (tries every address returned by
        ``getaddrinfo`` so IPv4 and IPv6 both work) but sets socket options
        before ``connect()``.

        Returns:
            The connected socket, with ``socket_timeout`` applied.

        Raises:
            OSError: the error of the last failed address, or a generic one
                when ``getaddrinfo`` returned no addresses. Any other
                exception raised while setting up a socket propagates
                unchanged after that socket has been closed.
        """

        def close_quietly(sock):
            # Best-effort clean close of a socket that failed to set up.
            if sock is None:
                return
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()

        err = None

        for family, socktype, proto, _canonname, socket_address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the failure and move on to the next address.
                err = exc
                close_quietly(sock)
            except BaseException:
                # Non-OSError failures (e.g. ValueError from a negative
                # timeout, TypeError from a bad keepalive option) abort the
                # attempt; close the socket first so it is not leaked.
                close_quietly(sock)
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Describe the endpoint as ``host:port`` for error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        """Host name or address this connection targets."""
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection wrapper that adds client-side caching on top of ``conn``.

    Commands that ``cache.is_cachable`` accepts are answered from ``cache``
    when a finished entry exists; otherwise they are sent to the server and
    the reply is stored. ``CLIENT TRACKING ON`` is enabled on every connect so
    the server pushes invalidations, which evict entries (or flush the whole
    cache when the pushed key list is ``None``). Everything else is delegated
    to the wrapped connection.
    """

    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """Wrap *conn* with *cache*.

        Args:
            conn: the connection that performs all actual I/O.
            cache: cache read from and written to; it may hold entries that
                were created through other connections.
            pool_lock: held while draining another connection's pending
                invalidation pushes.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        # Per-proxy lock around cache access made through this proxy.
        self._cache_lock = threading.RLock()
        # Cache key of the command in flight; None when it is not cachable.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._oss_cluster_maint_notifications_handler,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_cluster_handler_for_connection(
                oss_cluster_maint_notifications_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching.

        Raises:
            ConnectionError: the handshake metadata lacks the server name or
                version, or the server is not "redis" / is older than
                ``MIN_ALLOWED_VERSION`` (as judged by ``compare_versions``).
        """
        self._conn.connect()

        # The HELLO reply may be keyed by bytes or by str.
        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command, bypassing the cache entirely."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or arrange for it to be answered from the cache.

        Non-cachable commands go straight to the wrapped connection.
        Cachable ones require a ``keys`` keyword. When a cached entry for the
        command exists, nothing is sent and ``read_response`` serves it;
        otherwise an IN_PROGRESS placeholder is stored and the command is sent.

        Raises:
            ValueError: a cachable command was sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            cache_key = self._current_command_cache_key
            # Look up once: another connection's invalidation callback (which
            # holds its own lock, not this one) may evict the entry at any
            # time, so a second lookup could return None.
            entry = self._cache.get(cache_key)
            if entry:
                # We have to trigger invalidation processing in case if
                # it was cached by another connection to avoid
                # queueing invalidations in stale connections.
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                # Those invalidations may just have evicted the entry. Only
                # skip sending while it is still cached; otherwise
                # read_response would wait for a reply to a command that was
                # never sent.
                if self._cache.get(cache_key):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply for the last command, from the cache if possible.

        A finished (not IN_PROGRESS) cache entry is returned as a deep copy,
        so the cached value is never shared with the caller. Otherwise the
        reply is read from the wrapped connection and, for cachable commands,
        stored in the placeholder entry if that entry still exists (it may
        have been invalidated meanwhile). ``None`` replies are not cached.
        """
        with self._cache_lock:
            cache_key = self._current_command_cache_key
            # Single lookup so a concurrent eviction cannot turn a later
            # lookup into None between the checks below.
            entry = None if cache_key is None else self._cache.get(cache_key)
            if entry is not None and entry.status != CacheEntryStatus.IN_PROGRESS:
                res = copy.deepcopy(entry.cache_value)
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    @property
    def _maint_notifications_connection_handler(
        self,
    ) -> Optional[MaintNotificationsConnectionHandler]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._maint_notifications_connection_handler
        return None

    @_maint_notifications_connection_handler.setter
    def _maint_notifications_connection_handler(
        self, value: Optional[MaintNotificationsConnectionHandler]
    ):
        self._conn._maint_notifications_connection_handler = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        """Delegate socket creation to the wrapped connection and return its result."""
        return self._conn._connect()

    def _host_error(self):
        """Return the wrapped connection's endpoint description."""
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable CLIENT TRACKING and route invalidation pushes here."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Consume any push messages already waiting on the wrapped connection."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys from the cache (flush it when data[1] is None)."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True
                (forced to False when ssl_cert_reqs resolves to CERT_NONE).
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.

        The plain TCP socket is closed if wrapping or validation fails for
        any reason, not only OSError/RedisError (e.g. errors raised by
        OpenSSL during stapled OCSP validation).
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: OCSP validation was requested but cannot run
                (cryptography missing, or both OCSP modes enabled).
            ConnectionError: pure OCSP validation failed.

        If validation fails after the TLS handshake, the SSL socket is closed
        before the error propagates.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            # wrap_socket() takes over the file descriptor of *sock*, so
            # closing *sock* in the caller would not release it; close the
            # SSL socket here instead.
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Validate the server's stapled OCSP response over a second TLS connection."""
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; always release it once the check is done
        raw_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, raw_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            raw_sock.close()
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        :param path: Filesystem path of the Redis Unix domain socket.
        :param socket_timeout: Timeout applied to the socket once connected.
        :param kwargs: Forwarded to ``AbstractConnection.__init__``.
        """
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the ``(name, value)`` pairs shown in this connection's repr."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """
        Create a Unix domain socket connection.

        The socket uses ``socket_connect_timeout`` while connecting and is
        switched to ``socket_timeout`` afterwards.

        :return: The connected socket.
        :raises OSError: If the connection cannot be established. Any
            exception raised after the socket is created (including errors
            from ``settimeout`` and interrupts) closes the socket before it
            propagates.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_connect_timeout)
            sock.connect(self.path)
            sock.settimeout(self.socket_timeout)
        except BaseException:
            # Prevent ResourceWarnings / fd leaks for unclosed sockets, no
            # matter which step failed.
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                # shutdown() fails on a socket that never connected.
                pass
            sock.close()
            raise
        return sock

    def _host_error(self):
        """Return the identifier used for this endpoint in error messages."""
        return self.path
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """
    Convert a URL query-string value to a boolean.

    ``None`` and the empty string mean "not specified" and give ``None``.
    A string whose upper-cased form is listed in ``FALSE_STRINGS`` gives
    ``False``. Any other value is converted with ``bool()``.
    """
    unspecified = value is None or value == ""
    if unspecified:
        return None
    if isinstance(value, str):
        if value.upper() in FALSE_STRINGS:
            return False
    return bool(value)
def parse_ssl_verify_flags(value):
    """
    Parse ``ssl.VerifyFlags`` member names from a URL query-string value.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``. The square
    brackets are optional and whitespace around each name is ignored. Empty
    entries (``"[]"`` or a trailing comma) are skipped.

    :param value: Comma-separated ``VerifyFlags`` member names.
    :return: List of the matching ``VerifyFlags`` members, in input order.
    :raises ValueError: If a name is not a ``VerifyFlags`` member.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")

    # Look names up among the enum members only: hasattr()/getattr() on the
    # enum class would also accept unrelated class attributes such as
    # "_member_map_" or "mro" and return objects that are not flags.
    members = VerifyFlags.__members__
    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
# Maps a URL query-string option name to the callable that converts its raw
# string value before it is passed on as a keyword argument. Options without
# an entry here are forwarded unchanged as strings (see parse_url).
# NOTE(review): ``list`` applied to a string yields its individual
# characters, so ``retry_on_error=...`` from a URL likely does not produce a
# useful value - confirm the intended format.
URL_QUERY_ARGUMENT_PARSERS = dict(
    db=int,
    socket_timeout=float,
    socket_connect_timeout=float,
    socket_keepalive=to_bool,
    retry_on_timeout=to_bool,
    retry_on_error=list,
    max_connections=int,
    health_check_interval=int,
    ssl_check_hostname=to_bool,
    ssl_include_verify_flags=parse_ssl_verify_flags,
    ssl_exclude_verify_flags=parse_ssl_verify_flags,
    timeout=float,
)
def parse_url(url):
    """
    Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into keyword
    arguments for a connection pool.

    Query-string options are converted with the matching entry of
    ``URL_QUERY_ARGUMENT_PARSERS`` when one exists, otherwise kept as
    strings. Username, password, hostname and path are percent-decoded.
    For ``redis://``/``rediss://`` the URL path is used as ``db`` unless a
    ``db`` query option was given.

    :param url: The connection URL.
    :return: Dict of keyword arguments. ``connection_class`` is set for the
        ``unix://`` and ``rediss://`` schemes.
    :raises ValueError: If the scheme is unsupported, the port is invalid,
        or a query-string value cannot be converted.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # parse_qs maps each name to a list of values; only the first is used.
        if not value:
            continue
        # NOTE(review): parse_qs already percent-decodes values, so this
        # unquote() decodes a second time (e.g. "%2525" ends up as "%").
        # Kept as-is for backward compatibility with existing URLs.
        value = unquote(value[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = value
            continue
        try:
            kwargs[name] = parser(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid value for '{name}' in connection URL."
            ) from err

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # implied: url.scheme in ("redis", "rediss")
    if url.hostname:
        kwargs["host"] = unquote(url.hostname)
    # urlparse already returns the port as an int (or None); evaluating it
    # raises ValueError for a malformed or out-of-range port.
    port = url.port
    if port:
        kwargs["port"] = port

    # If there's a path argument, use it as the db argument if a
    # querystring value wasn't specified
    if url.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(url.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if url.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs
# Type variable bound to ConnectionPool, used by ``from_url`` so that calling
# it on a subclass is typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """
    Abstract interface for connection pools.

    Every method is abstract; concrete pools must implement all of them.
    The bodies here only document the expected contract.
    """

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version configured for the pool."""

    @abstractmethod
    def reset(self):
        """Reinitialize the pool's internal bookkeeping."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """
        Check a connection out of the pool.

        Passing arguments is deprecated; call it without arguments.
        """

    @abstractmethod
    def get_encoder(self):
        """Return the encoder matching the pool's encoding settings."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return ``connection`` to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections, optionally including busy ones."""

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Apply ``retry`` to the pool's connections."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate the pool's connections with ``token``."""
class MaintNotificationsAbstractConnectionPool:
    """
    Mixin holding the maintenance-notifications logic of connection pools.

    Concrete pool classes mix this in and provide the hooks declared as
    abstract below: the ``connection_kwargs`` property (getter and setter),
    ``_get_pool_lock``, ``_get_free_connections`` and
    ``_get_in_use_connections``.

    This class is not intended to be used directly!

    When notifications are enabled at construction time, the pool uses one
    of two handlers. It either uses an externally supplied
    ``OSSMaintNotificationsHandler`` or creates its own
    ``MaintNotificationsPoolHandler``.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        """
        Set up maintenance notification handling for the pool.

        :param maint_notifications_config: Configuration to use. When omitted
            and ``check_protocol_version(kwargs.get("protocol"), 3)`` holds, a
            default ``MaintNotificationsConfig()`` is used.
        :param oss_cluster_maint_notifications_handler: Optional cluster
            handler. When given (and notifications are enabled), it is used
            instead of creating a ``MaintNotificationsPoolHandler``.
        :param kwargs: The pool's connection keyword arguments; only
            ``protocol`` is inspected here.
        :raises RedisError: If notifications are enabled without RESP3.
        """
        resp3_in_use = check_protocol_version(kwargs.get("protocol"), 3)

        # RESP3 without an explicit config falls back to the default config.
        if maint_notifications_config is None and resp3_in_use:
            maint_notifications_config = MaintNotificationsConfig()

        notifications_wanted = (
            maint_notifications_config and maint_notifications_config.enabled
        )
        if not notifications_wanted:
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not resp3_in_use:
            raise RedisError(
                "Maintenance notifications handlers on connection are only supported with RESP version 3"
            )

        if oss_cluster_maint_notifications_handler:
            # The handler attribute must be assigned before the kwargs update:
            # that update calls maint_notifications_enabled(), which reads it.
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
            self._update_connection_kwargs_for_maint_notifications(
                oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
            )
            self._maint_notifications_pool_handler = None
            return

        # No cluster handler supplied: this pool gets its own handler.
        self._oss_cluster_maint_notifications_handler = None
        self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
            self, maint_notifications_config
        )
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler
        )

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to build new connections (a mutable dict)."""

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        """Replace the keyword arguments used to build new connections."""

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock that guards the pool's connection collections."""

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the idle connections held by the pool."""

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently checked out of the pool."""

    def maint_notifications_enabled(self):
        """
        Report whether maintenance notifications are enabled for this pool.

        The config comes from the cluster handler when one is set, otherwise
        from the pool handler.

        Returns:
            ``None`` when neither handler is set; otherwise the active
            config's ``enabled`` value.
        """
        if self._oss_cluster_maint_notifications_handler:
            active_config = self._oss_cluster_maint_notifications_handler.config
        elif self._maint_notifications_pool_handler:
            active_config = self._maint_notifications_pool_handler.config
        else:
            active_config = None

        return active_config and active_config.enabled

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the maintenance notifications configuration.

        Intended for pools created without maintenance notifications that
        later need them enabled. Future connection kwargs are updated, idle
        connections are disconnected and busy ones are marked for reconnect.

        :raises ValueError: If notifications are currently enabled and the
            new config would disable them.
        """
        disabling = not maint_notifications_config.enabled
        if self.maint_notifications_enabled() and disabling:
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        elif self._maint_notifications_pool_handler:
            # Reuse the existing pool handler and only swap its config.
            self._maint_notifications_pool_handler.config = (
                maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

        # Pool settings first, then future-connection kwargs, then the
        # connections that already exist.
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.

        Does nothing unless notifications are enabled. When both handlers are
        given, ``maint_notifications_config`` ends up holding the cluster
        handler's config.
        """
        if not self.maint_notifications_enabled():
            return

        handler_settings = {}
        if maint_notifications_pool_handler:
            handler_settings["maint_notifications_pool_handler"] = (
                maint_notifications_pool_handler
            )
            handler_settings["maint_notifications_config"] = (
                maint_notifications_pool_handler.config
            )
        if oss_cluster_maint_notifications_handler:
            handler_settings["oss_cluster_maint_notifications_handler"] = (
                oss_cluster_maint_notifications_handler
            )
            handler_settings["maint_notifications_config"] = (
                oss_cluster_maint_notifications_handler.config
            )
        if handler_settings:
            self.connection_kwargs.update(handler_settings)

        # Store the original connection parameters for maintenance
        # notifications, but only once: a missing/None orig_host_address
        # means they have not been recorded yet.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            current = self.connection_kwargs
            self.connection_kwargs.update(
                {
                    "orig_host_address": current.get("host"),
                    "orig_socket_timeout": current.get("socket_timeout", None),
                    "orig_socket_connect_timeout": current.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the maintenance notifications config for all connections in the pool.

        The cluster handler takes precedence over the pool handler. Idle
        connections are disconnected; in-use connections are marked for
        reconnect.

        :raises ValueError: If a connection is visited and neither handler
            was given.
        """
        missing_handler = (
            "Either maint_notifications_pool_handler or "
            "oss_cluster_maint_notifications_handler must be set"
        )
        cluster_handler = oss_cluster_maint_notifications_handler
        pool_handler = maint_notifications_pool_handler

        with self._get_pool_lock():
            for idle_conn in self._get_free_connections():
                if cluster_handler:
                    idle_conn.set_maint_notifications_cluster_handler_for_connection(
                        cluster_handler
                    )
                    idle_conn.maint_notifications_config = cluster_handler.config
                elif pool_handler:
                    idle_conn.set_maint_notifications_pool_handler_for_connection(
                        pool_handler
                    )
                    idle_conn.maint_notifications_config = pool_handler.config
                else:
                    raise ValueError(missing_handler)
                idle_conn.disconnect()

            for busy_conn in self._get_in_use_connections():
                if cluster_handler:
                    busy_conn.maint_notifications_config = cluster_handler.config
                    busy_conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=cluster_handler
                    )
                elif pool_handler:
                    busy_conn.set_maint_notifications_pool_handler_for_connection(
                        pool_handler
                    )
                    busy_conn.maint_notifications_config = pool_handler.config
                else:
                    raise ValueError(missing_handler)
                busy_conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy matching value matches every connection, as does an
        unrecognised ``matching_pattern``.
        """
        if matching_pattern == "connected_address":
            return not matching_address or conn.getpeername() == matching_address
        if matching_pattern == "configured_address":
            return not matching_address or conn.host == matching_address
        if matching_pattern == "notification_hash":
            return (
                not matching_notification_hash
                or conn.maintenance_notification_hash == matching_notification_hash
            )
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Steps run in this order: state, notification hash (only when
        ``update_notification_hash``), temporary host, temporary relaxed
        timeout, optional reset of the temporary settings. Finally
        ``update_current_socket_timeout(relaxed_timeout)`` is always called,
        even when ``relaxed_timeout`` is None.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_host_address or reset_relaxed_timeout:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        In-use connections are processed first, then (optionally) free ones.
        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """

        def _apply_to(connections):
            # Update every connection in ``connections`` that matches.
            for candidate in connections:
                if not self._should_update_connection(
                    candidate,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    continue
                self.update_connection_settings(
                    candidate,
                    state=state,
                    maintenance_notification_hash=maintenance_notification_hash,
                    host_address=host_address,
                    relaxed_timeout=relaxed_timeout,
                    update_notification_hash=update_notification_hash,
                    reset_host_address=reset_host_address,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                )

        with self._get_pool_lock():
            _apply_to(self._get_in_use_connections())
            if include_free_connections:
                _apply_to(self._get_free_connections())

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(**kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved;
            when falsy, every in-use connection is marked.
        """
        with self._get_pool_lock():
            for busy_conn in self._get_in_use_connections():
                if not self._should_update_connection(
                    busy_conn, "connected_address", moving_address_src
                ):
                    continue
                busy_conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved;
            when falsy, every free connection is disconnected.
        """
        with self._get_pool_lock():
            for idle_conn in self._get_free_connections():
                if not self._should_update_connection(
                    idle_conn, "connected_address", moving_address_src
                ):
                    continue
                idle_conn.disconnect()
2579class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
2580 """
2581 Create a connection pool. ``If max_connections`` is set, then this
2582 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
2583 limit is reached.
2585 By default, TCP connections are created unless ``connection_class``
2586 is specified. Use class:`.UnixDomainSocketConnection` for
2587 unix sockets.
2588 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
2590 If ``maint_notifications_config`` is provided, the connection pool will support
2591 maintenance notifications.
2592 Maintenance notifications are supported only with RESP3.
2593 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2594 the maintenance notifications will be enabled by default.
2596 Any additional keyword arguments are passed to the constructor of
2597 ``connection_class``.
2598 """
2600 @classmethod
2601 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
2602 """
2603 Return a connection pool configured from the given URL.
2605 For example::
2607 redis://[[username]:[password]]@localhost:6379/0
2608 rediss://[[username]:[password]]@localhost:6379/0
2609 unix://[username@]/path/to/socket.sock?db=0[&password=password]
2611 Three URL schemes are supported:
2613 - `redis://` creates a TCP socket connection. See more at:
2614 <https://www.iana.org/assignments/uri-schemes/prov/redis>
2615 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
2616 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
2617 - ``unix://``: creates a Unix Domain Socket connection.
2619 The username, password, hostname, path and all querystring values
2620 are passed through urllib.parse.unquote in order to replace any
2621 percent-encoded values with their corresponding characters.
2623 There are several ways to specify a database number. The first value
2624 found will be used:
2626 1. A ``db`` querystring option, e.g. redis://localhost?db=0
2627 2. If using the redis:// or rediss:// schemes, the path argument
2628 of the url, e.g. redis://localhost/0
2629 3. A ``db`` keyword argument to this function.
2631 If none of these options are specified, the default db=0 is used.
2633 All querystring options are cast to their appropriate Python types.
2634 Boolean arguments can be specified with string values "True"/"False"
2635 or "Yes"/"No". Values that cannot be properly cast cause a
2636 ``ValueError`` to be raised. Once parsed, the querystring arguments
2637 and keyword arguments are passed to the ``ConnectionPool``'s
2638 class initializer. In the case of conflicting arguments, querystring
2639 arguments always win.
2640 """
2641 url_options = parse_url(url)
2643 if "connection_class" in kwargs:
2644 url_options["connection_class"] = kwargs["connection_class"]
2646 kwargs.update(url_options)
2647 return cls(**kwargs)
2649 def __init__(
2650 self,
2651 connection_class=Connection,
2652 max_connections: Optional[int] = None,
2653 cache_factory: Optional[CacheFactoryInterface] = None,
2654 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
2655 **connection_kwargs,
2656 ):
2657 max_connections = max_connections or 2**31
2658 if not isinstance(max_connections, int) or max_connections < 0:
2659 raise ValueError('"max_connections" must be a positive integer')
2661 self.connection_class = connection_class
2662 self._connection_kwargs = connection_kwargs
2663 self.max_connections = max_connections
2664 self.cache = None
2665 self._cache_factory = cache_factory
2667 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
2668 if not check_protocol_version(self._connection_kwargs.get("protocol"), 3):
2669 raise RedisError("Client caching is only supported with RESP version 3")
2671 cache = self._connection_kwargs.get("cache")
2673 if cache is not None:
2674 if not isinstance(cache, CacheInterface):
2675 raise ValueError("Cache must implement CacheInterface")
2677 self.cache = cache
2678 else:
2679 if self._cache_factory is not None:
2680 self.cache = self._cache_factory.get_cache()
2681 else:
2682 self.cache = CacheFactory(
2683 self._connection_kwargs.get("cache_config")
2684 ).get_cache()
2686 connection_kwargs.pop("cache", None)
2687 connection_kwargs.pop("cache_config", None)
2689 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
2690 if self._event_dispatcher is None:
2691 self._event_dispatcher = EventDispatcher()
2693 # a lock to protect the critical section in _checkpid().
2694 # this lock is acquired when the process id changes, such as
2695 # after a fork. during this time, multiple threads in the child
2696 # process could attempt to acquire this lock. the first thread
2697 # to acquire the lock will reset the data structures and lock
2698 # object of this pool. subsequent threads acquiring this lock
2699 # will notice the first thread already did the work and simply
2700 # release the lock.
2702 self._fork_lock = threading.RLock()
2703 self._lock = threading.RLock()
2705 MaintNotificationsAbstractConnectionPool.__init__(
2706 self,
2707 maint_notifications_config=maint_notifications_config,
2708 **connection_kwargs,
2709 )
2711 self.reset()
2713 def __repr__(self) -> str:
2714 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
2715 return (
2716 f"<{self.__class__.__module__}.{self.__class__.__name__}"
2717 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
2718 f"({conn_kwargs})>)>"
2719 )
2721 @property
2722 def connection_kwargs(self) -> Dict[str, Any]:
2723 return self._connection_kwargs
2725 @connection_kwargs.setter
2726 def connection_kwargs(self, value: Dict[str, Any]):
2727 self._connection_kwargs = value
2729 def get_protocol(self):
2730 """
2731 Returns:
2732 The RESP protocol version, or ``None`` if the protocol is not specified,
2733 in which case the server default will be used.
2734 """
2735 return self.connection_kwargs.get("protocol", None)
2737 def reset(self) -> None:
2738 self._created_connections = 0
2739 self._available_connections = []
2740 self._in_use_connections = set()
2742 # this must be the last operation in this method. while reset() is
2743 # called when holding _fork_lock, other threads in this process
2744 # can call _checkpid() which compares self.pid and os.getpid() without
2745 # holding any lock (for performance reasons). keeping this assignment
2746 # as the last operation ensures that those other threads will also
2747 # notice a pid difference and block waiting for the first thread to
2748 # release _fork_lock. when each of these threads eventually acquire
2749 # _fork_lock, they will notice that another thread already called
2750 # reset() and they will immediately release _fork_lock and continue on.
2751 self.pid = os.getpid()
2753 def _checkpid(self) -> None:
2754 # _checkpid() attempts to keep ConnectionPool fork-safe on modern
2755 # systems. this is called by all ConnectionPool methods that
2756 # manipulate the pool's state such as get_connection() and release().
2757 #
2758 # _checkpid() determines whether the process has forked by comparing
2759 # the current process id to the process id saved on the ConnectionPool
2760 # instance. if these values are the same, _checkpid() simply returns.
2761 #
2762 # when the process ids differ, _checkpid() assumes that the process
2763 # has forked and that we're now running in the child process. the child
2764 # process cannot use the parent's file descriptors (e.g., sockets).
2765 # therefore, when _checkpid() sees the process id change, it calls
2766 # reset() in order to reinitialize the child's ConnectionPool. this
2767 # will cause the child to make all new connection objects.
2768 #
2769 # _checkpid() is protected by self._fork_lock to ensure that multiple
2770 # threads in the child process do not call reset() multiple times.
2771 #
2772 # there is an extremely small chance this could fail in the following
2773 # scenario:
2774 # 1. process A calls _checkpid() for the first time and acquires
2775 # self._fork_lock.
2776 # 2. while holding self._fork_lock, process A forks (the fork()
2777 # could happen in a different thread owned by process A)
2778 # 3. process B (the forked child process) inherits the
2779 # ConnectionPool's state from the parent. that state includes
2780 # a locked _fork_lock. process B will not be notified when
2781 # process A releases the _fork_lock and will thus never be
2782 # able to acquire the _fork_lock.
2783 #
2784 # to mitigate this possible deadlock, _checkpid() will only wait 5
2785 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
2786 # that time it is assumed that the child is deadlocked and a
2787 # redis.ChildDeadlockedError error is raised.
2788 if self.pid != os.getpid():
2789 acquired = self._fork_lock.acquire(timeout=5)
2790 if not acquired:
2791 raise ChildDeadlockedError
2792 # reset() the instance for the new process if another thread
2793 # hasn't already done so
2794 try:
2795 if self.pid != os.getpid():
2796 self.reset()
2797 finally:
2798 self._fork_lock.release()
2800 @deprecated_args(
2801 args_to_warn=["*"],
2802 reason="Use get_connection() without args instead",
2803 version="5.3.0",
2804 )
2805 def get_connection(self, command_name=None, *keys, **options) -> "Connection":
2806 "Get a connection from the pool"
2808 self._checkpid()
2809 with self._lock:
2810 try:
2811 connection = self._available_connections.pop()
2812 except IndexError:
2813 connection = self.make_connection()
2814 self._in_use_connections.add(connection)
2816 try:
2817 # ensure this connection is connected to Redis
2818 connection.connect()
2819 # connections that the pool provides should be ready to send
2820 # a command. if not, the connection was either returned to the
2821 # pool before all data has been read or the socket has been
2822 # closed. either way, reconnect and verify everything is good.
2823 try:
2824 if (
2825 connection.can_read()
2826 and self.cache is None
2827 and not self.maint_notifications_enabled()
2828 ):
2829 raise ConnectionError("Connection has data")
2830 except (ConnectionError, TimeoutError, OSError):
2831 connection.disconnect()
2832 connection.connect()
2833 if connection.can_read():
2834 raise ConnectionError("Connection not ready")
2835 except BaseException:
2836 # release the connection back to the pool so that we don't
2837 # leak it
2838 self.release(connection)
2839 raise
2840 return connection
2842 def get_encoder(self) -> Encoder:
2843 "Return an encoder based on encoding settings"
2844 kwargs = self.connection_kwargs
2845 return Encoder(
2846 encoding=kwargs.get("encoding", "utf-8"),
2847 encoding_errors=kwargs.get("encoding_errors", "strict"),
2848 decode_responses=kwargs.get("decode_responses", False),
2849 )
2851 def make_connection(self) -> "ConnectionInterface":
2852 "Create a new connection"
2853 if self._created_connections >= self.max_connections:
2854 raise MaxConnectionsError("Too many connections")
2855 self._created_connections += 1
2857 kwargs = dict(self.connection_kwargs)
2859 if self.cache is not None:
2860 return CacheProxyConnection(
2861 self.connection_class(**kwargs), self.cache, self._lock
2862 )
2863 return self.connection_class(**kwargs)
def release(self, connection: "Connection") -> None:
    """
    Hand ``connection`` back to the pool after use.

    - A connection the pool does not track as in use is ignored.
    - A tracked connection whose ``pid`` does not match the pool (see
      owns_connection) is disconnected and dropped.
    - Any other connection is disconnected first if ``should_reconnect()``
      asks for it. It is then made available again, and an
      ``AfterConnectionReleasedEvent`` is dispatched.
    """
    self._checkpid()
    with self._lock:
        try:
            self._in_use_connections.remove(connection)
        except KeyError:
            # Gracefully ignore connections this pool did not hand out.
            return

        if not self.owns_connection(connection):
            # Foreign connection: never put it back. _created_connections
            # stays unchanged because this pool did not create it.
            connection.disconnect()
            return

        if connection.should_reconnect():
            connection.disconnect()
        self._available_connections.append(connection)
        self._event_dispatcher.dispatch(AfterConnectionReleasedEvent(connection))
def owns_connection(self, connection: "Connection") -> bool:
    """
    Return whether ``connection`` carries the same ``pid`` as this pool.

    release() uses this check to decide between making a connection
    available again and disconnecting and dropping it.
    """
    return connection.pid == self.pid
def disconnect(self, inuse_connections: bool = True) -> None:
    """
    Disconnect the connections this pool knows about.

    Idle connections are always disconnected. When ``inuse_connections`` is
    True (the default), connections that are currently checked out are
    disconnected as well; other threads may be using those. Everything
    happens while the pool lock is held.
    """
    self._checkpid()
    with self._lock:
        groups = [self._available_connections]
        if inuse_connections:
            groups.append(self._in_use_connections)
        for group in groups:
            for conn in group:
                conn.disconnect()
def close(self) -> None:
    """
    Shut the pool down.

    Delegates to disconnect() with its default arguments, which covers
    idle and checked-out connections alike.
    """
    self.disconnect()
def set_retry(self, retry: "Retry") -> None:
    """
    Apply ``retry`` to future and existing connections of the pool.

    The policy is stored in ``connection_kwargs`` so connections created
    later pick it up. It is also assigned to every idle and in-use
    connection.

    Connections are enumerated through ``_get_free_connections()`` and
    ``_get_in_use_connections()`` rather than ``_available_connections`` and
    ``_in_use_connections`` directly. BlockingConnectionPool overrides those
    helpers, and its reset() builds a queue plus ``_connections`` instead of
    the two attributes, so direct access broke there. Each result is copied
    into a list so that concurrent check-outs and releases cannot resize a
    container while it is being iterated.
    """
    self.connection_kwargs.update({"retry": retry})
    for conn in list(self._get_free_connections()):
        conn.retry = retry
    for conn in list(self._get_in_use_connections()):
        conn.retry = retry
def re_auth_callback(self, token: TokenInterface):
    """
    Push a refreshed auth ``token`` to the pool's connections.

    While the pool lock is held:

    - Each idle connection is sent ``AUTH <token oid> <token value>`` and its
      reply is read. Both steps go through the connection's own retry
      policy, with ``self._mock`` as the failure callback.
    - Each in-use connection only receives the token via
      ``set_re_auth_token``. Presumably this avoids writing to a connection
      another caller is using; confirm in Connection.
    """

    def ignore_failure(error):
        return self._mock(error)

    with self._lock:
        for idle in self._available_connections:
            idle.retry.call_with_retry(
                lambda: idle.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                ),
                ignore_failure,
            )
            idle.retry.call_with_retry(
                lambda: idle.read_response(), ignore_failure
            )
        for busy in self._in_use_connections:
            busy.set_re_auth_token(token)
def _get_pool_lock(self):
    """Return ``self._lock``, the lock this pool takes around its connection bookkeeping."""
    pool_lock = self._lock
    return pool_lock
def _get_free_connections(self):
    """
    Return the pool's list of idle connections.

    The live list object is returned, not a copy. The lock is held only
    while the attribute is read.
    """
    lock = self._lock
    lock.acquire()
    try:
        return self._available_connections
    finally:
        lock.release()
def _get_in_use_connections(self):
    """
    Return the pool's set of checked-out connections.

    The live set object is returned, not a copy. The lock is held only
    while the attribute is read.
    """
    lock = self._lock
    lock.acquire()
    try:
        return self._in_use_connections
    finally:
        lock.release()
def _mock(self, error: "RedisError") -> None:
    """
    No-op failure callback for ``Retry.call_with_retry``.

    re_auth_callback passes ``lambda error: self._mock(error)`` to the
    synchronous ``Retry`` used by this module, and nothing awaits the
    result. It must therefore be a plain function. As an ``async def``,
    every call produced a coroutine that never ran and was reported as
    "never awaited" (``RuntimeWarning``) when garbage-collected.

    :param error: the error raised by the retried operation; ignored.
    """
    return None
2960class BlockingConnectionPool(ConnectionPool):
2961 """
2962 Thread-safe blocking connection pool::
2964 >>> from redis.client import Redis
2965 >>> client = Redis(connection_pool=BlockingConnectionPool())
2967 It performs the same function as the default
2968 :py:class:`~redis.ConnectionPool` implementation, in that,
2969 it maintains a pool of reusable connections that can be shared by
2970 multiple redis clients (safely across threads if required).
2972 The difference is that, in the event that a client tries to get a
2973 connection from the pool when all of connections are in use, rather than
2974 raising a :py:class:`~redis.ConnectionError` (as the default
2975 :py:class:`~redis.ConnectionPool` implementation does), it
2976 makes the client wait ("blocks") for a specified number of seconds until
2977 a connection becomes available.
2979 Use ``max_connections`` to increase / decrease the pool size::
2981 >>> pool = BlockingConnectionPool(max_connections=10)
2983 Use ``timeout`` to tell it either how many seconds to wait for a connection
2984 to become available, or to block forever:
2986 >>> # Block forever.
2987 >>> pool = BlockingConnectionPool(timeout=None)
2989 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
2990 >>> # not available.
2991 >>> pool = BlockingConnectionPool(timeout=5)
2992 """
def __init__(
    self,
    max_connections=50,
    timeout=20,
    connection_class=Connection,
    queue_class=LifoQueue,
    **connection_kwargs,
):
    """
    Create a blocking pool.

    :param max_connections: capacity of the pool's queue, i.e. the number
        of ``None`` placeholders reset() puts in it.
    :param timeout: seconds get_connection() waits for a connection;
        ``None`` blocks forever.
    :param connection_class: class instantiated for each new connection.
    :param queue_class: queue type holding idle connections and placeholders.
    :param connection_kwargs: forwarded to the base ``ConnectionPool``.
    """
    # reset() reads queue_class and _in_maintenance. It presumably runs
    # from the base-class initializer, so these attributes are assigned
    # before delegating to super().__init__().
    self._locked = False
    self._in_maintenance = False
    self.timeout = timeout
    self.queue_class = queue_class
    super().__init__(
        max_connections=max_connections,
        connection_class=connection_class,
        **connection_kwargs,
    )
def reset(self):
    """
    Reinitialize the pool with a fresh queue of ``None`` placeholders and
    no tracked connections.

    The queue is created as ``self.queue_class(self.max_connections)`` and
    filled with ``None`` until it reports ``Full``. Each ``None`` tells
    get_connection() to create a connection on demand. ``_connections``
    records every connection created afterwards so that disconnect() can
    reach them. In maintenance mode the pool lock is held meanwhile.
    """
    # Track lock ownership in a local rather than in ``self._locked``. A
    # flag on the instance is shared by every thread and by nested calls,
    # so another call could clear it (leaving the lock held forever) or set
    # it (making this call release a lock it never acquired).
    locked = False
    if self._in_maintenance:
        self._lock.acquire()
        locked = True
    try:
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []
    finally:
        if locked:
            self._lock.release()

    # this must be the last operation in this method. while reset() is
    # called when holding _fork_lock, other threads in this process
    # can call _checkpid() which compares self.pid and os.getpid() without
    # holding any lock (for performance reasons). keeping this assignment
    # as the last operation ensures that those other threads will also
    # notice a pid difference and block waiting for the first thread to
    # release _fork_lock. when each of these threads eventually acquire
    # _fork_lock, they will notice that another thread already called
    # reset() and they will immediately release _fork_lock and continue on.
    self.pid = os.getpid()
def make_connection(self):
    """
    Make a fresh connection and record it in ``self._connections``.

    The connection is built from ``self.connection_kwargs``. When a cache
    is configured, it is wrapped in a ``CacheProxyConnection`` that shares
    the pool lock. This method does not call ``connect()``. In maintenance
    mode the pool lock is held while the connection is created and
    recorded.

    :return: the new connection.
    """
    # Track lock ownership in a local rather than in ``self._locked``. The
    # instance flag is shared by all threads and by the enclosing
    # get_connection() call, so relying on it can release a lock this call
    # never acquired, or skip releasing one it did acquire.
    locked = False
    if self._in_maintenance:
        self._lock.acquire()
        locked = True
    try:
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs),
                self.cache,
                self._lock,
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection
    finally:
        if locked:
            self._lock.release()
@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection, blocking for ``self.timeout`` until a connection
    is available from the pool.

    If the connection returned is ``None`` then creates a new connection.
    Because we use a last-in first-out queue, the existing connections
    (having been returned to the pool after the initial ``None`` values
    were added) will be returned before ``None`` values. This means we only
    create new connections when we need to, i.e.: the actual number of
    connections will only increase in response to demand.

    :raises ConnectionError: if no connection becomes available within
        ``self.timeout`` seconds, or if the connection still has unread
        data after reconnecting.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # Track lock ownership in a local rather than in ``self._locked``.
    # make_connection() below re-enters the lock, and with a shared
    # instance flag its cleanup cleared the flag, so this method then never
    # released its own acquisition and the lock stayed held. Other threads
    # could also flip the flag.
    locked = False
    connection = None
    if self._in_maintenance:
        self._lock.acquire()
        locked = True
    try:
        # NOTE(review): in maintenance mode the lock is held for the whole
        # blocking wait below, and release() takes the same lock in that
        # mode. With an exhausted pool, a waiting caller therefore cannot
        # receive a released connection before its timeout expires.
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. Pass
            # ``timeout=None`` to block forever instead.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # Give the slot back; otherwise a failed construction
                # permanently shrinks the pool's capacity.
                try:
                    self.pool.put_nowait(None)
                except Full:
                    pass
                raise
    finally:
        if locked:
            self._lock.release()

    try:
        # ensure this connection is connected to Redis
        connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if connection.can_read():
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # release the connection back to the pool so that we don't leak it
        self.release(connection)
        raise

    return connection
def release(self, connection):
    """
    Release ``connection`` back to the pool.

    - A connection this pool does not own (see owns_connection) is
      disconnected and its slot is returned as a ``None`` placeholder, so
      the pool can recreate a connection if needed.
    - An owned connection is disconnected first if ``should_reconnect()``
      asks for it, then queued again.

    In both cases a full queue is tolerated silently. After reset() (e.g.
    following a fork) the queue is already refilled with placeholders, so
    there is no slot left to return. In maintenance mode the pool lock is
    held for the duration.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # Track lock ownership in a local rather than in ``self._locked``. A
    # shared instance flag can make one call release a lock it never
    # acquired, or skip releasing one it did acquire.
    locked = False
    if self._in_maintenance:
        self._lock.acquire()
        locked = True
    try:
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            try:
                self.pool.put_nowait(None)
            except Full:
                # The queue was already refilled (e.g. by reset() after a
                # fork); the slot is accounted for.
                pass
            return
        if connection.should_reconnect():
            connection.disconnect()
        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass
    finally:
        if locked:
            self._lock.release()
def disconnect(self, inuse_connections: bool = True):
    """
    Disconnect either every connection this pool created or only the idle ones.

    When ``inuse_connections`` is True (the default), every connection
    recorded in ``_connections`` is disconnected, including checked-out
    ones. Otherwise only the connections currently waiting in the queue are
    disconnected. In maintenance mode the pool lock is held for the
    duration.
    """
    self._checkpid()

    # Track lock ownership in a local rather than in ``self._locked``. A
    # shared instance flag can make this call release a lock it never
    # acquired, or skip releasing one it did acquire.
    locked = False
    if self._in_maintenance:
        self._lock.acquire()
        locked = True
    try:
        if inuse_connections:
            connections = self._connections
        else:
            connections = self._get_free_connections()
        for connection in connections:
            connection.disconnect()
    finally:
        if locked:
            self._lock.release()
def _get_free_connections(self):
    """
    Return the set of idle connections currently sitting in the queue.

    Falsy queue entries, i.e. the ``None`` placeholders for slots without a
    connection yet, are skipped. The pool lock is held while the queue is
    scanned.
    """
    with self._lock:
        queued = self.pool.queue
        return set(filter(None, queued))
def _get_in_use_connections(self):
    """
    Return the set of created connections that are not queued as idle.

    ``_connections`` records every connection the pool created, so
    subtracting the ones currently in the queue leaves the checked-out
    connections. The pool lock is held during the computation.
    """
    with self._lock:
        idle = set(filter(None, self.pool.queue))
        return set(self._connections) - idle
def set_in_maintenance(self, in_maintenance: bool):
    """
    Switch this BlockingConnectionPool's maintenance mode on or off.

    While the flag is set, reset(), make_connection(), get_connection(),
    release() and disconnect() take the pool lock for the duration of their
    work, which serializes them. Maintenance mode is meant to be active only
    while a MOVING notification is being processed.
    """
    self._in_maintenance = in_maintenance