Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import ABC, abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .driver_info import DriverInfo, resolve_driver_info
39from .event import AfterConnectionReleasedEvent, EventDispatcher
40from .exceptions import (
41 AuthenticationError,
42 AuthenticationWrongNumberOfArgsError,
43 ChildDeadlockedError,
44 ConnectionError,
45 DataError,
46 MaxConnectionsError,
47 RedisError,
48 ResponseError,
49 TimeoutError,
50)
51from .maint_notifications import (
52 MaintenanceState,
53 MaintNotificationsConfig,
54 MaintNotificationsConnectionHandler,
55 MaintNotificationsPoolHandler,
56)
57from .retry import Retry
58from .utils import (
59 CRYPTOGRAPHY_AVAILABLE,
60 HIREDIS_AVAILABLE,
61 SSL_AVAILABLE,
62 compare_versions,
63 deprecated_args,
64 ensure_string,
65 format_error_message,
66 str_if_bytes,
67)
69if SSL_AVAILABLE:
70 import ssl
71 from ssl import VerifyFlags
72else:
73 ssl = None
74 VerifyFlags = None
76if HIREDIS_AVAILABLE:
77 import hiredis
# RESP framing tokens used when serializing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when a connection is created with protocol=None.
DEFAULT_RESP_VERSION = 2

# Unique marker object meaning "no value supplied / do not update".
SENTINEL = object()

# Prefer the hiredis-backed parser when the extension is installed.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
class HiredisRespSerializer:
    """Serialize commands to RESP using hiredis' C implementation."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        head, rest = args[0], args[1:]
        # A command name such as 'CONFIG GET' holds literal sub-arguments that
        # the server expects as separate arguments, so split them off.
        if isinstance(head, str):
            args = (*head.encode().split(), *rest)
        elif b" " in head:
            args = (*head.split(), *rest)
        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Surface unpackable argument types as a redis DataError while
            # keeping the original traceback.
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
class PythonRespSerializer:
    """Serialize commands to RESP in pure Python.

    Small arguments are accumulated into a single bytes chunk. An argument
    longer than ``buffer_cutoff``, any memoryview argument, or a pending chunk
    that has already grown past ``buffer_cutoff`` causes the argument to be
    emitted as its own list element instead of being concatenated.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        head = args[0]
        # The client may embed literal arguments in the command name (e.g.
        # 'CONFIG GET'); Redis wants them sent separately. The split pieces
        # are bytes so that encode() does not re-encode them.
        if isinstance(head, str):
            args = (*head.encode().split(), *args[1:])
        elif b" " in head:
            args = (*head.split(), *args[1:])

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            emit_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if emit_separately:
                # Avoid one large malloc: flush what we have plus the length
                # header, then hand the value itself over as its own chunk.
                chunks.append(pending + header)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))

        chunks.append(pending)
        return chunks
class ConnectionInterface:
    """Method contract implemented by Redis connection classes.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document intent but are not enforced when a
    subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs used to build ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run when the connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback added with ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Instantiate ``parser_class`` and install it as the parser."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version of this connection."""

    @abstractmethod
    def connect(self):
        """Connect to the Redis server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Initialize a newly opened connection."""

    @abstractmethod
    def disconnect(self, *args):
        """Disconnect from the Redis server."""

    @abstractmethod
    def check_health(self):
        """Check the health of the connection."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already packed ``command``."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command built from ``args``."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether data can be read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the next response from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize one command into the Redis protocol."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands into the Redis protocol."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for re-authentication."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        Raises:
            RedisError: if notifications are enabled but no parser is given,
                or the parser is neither a hiredis nor a RESP3 parser.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of ``maint_notifications_pool_handler``
        to this connection's parser, creating or updating the connection
        handler with the pool handler's config."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Send the maintenance-notifications handshake when applicable."""
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # a ResponseError from the handshake is only logged (at debug level)
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Issue ``CLIENT MAINT_NOTIFICATIONS ON`` and verify the reply.

        Raises:
            ValueError: if the connection has no host.
            ResponseError: if the server does not answer "OK" (swallowed and
                logged when ``maint_notifications_config.enabled == "auto"``).
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # In "auto" mode an unsupported server is not fatal: record it
                # at debug level and keep the connection.
                import logging

                logger = logging.getLogger(__name__)
                logger.debug("Failed to enable maintenance notifications: %s", e)
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # AF_INET/AF_INET6 peers are (host, port, ...) tuples: return the
                # host part. Other families (e.g. AF_UNIX) report a plain path,
                # whose first character is not an IP, so fall through instead.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.

        For socket families whose address is a tuple (TCP) this is the host
        part; for families reporting a plain address (e.g. a Unix socket path)
        the address is returned unchanged. Returns None without a socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            peer_addr = conn_socket.getpeername()
            if isinstance(peer_addr, tuple):
                return peer_addr[0]
            return peer_addr
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply ``relaxed_timeout`` to the live socket and parser.

        A value of -1 restores ``self.socket_timeout``.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's read path.

        ``None`` (blocking) is propagated as well, so that restoring an
        original "no timeout" setting also resets a previously relaxed value
        held by the RESP3 socket buffer.
        """
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser):
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        ``tmp_relaxed_timeout == -1`` leaves both timeouts unchanged.
        NOTE(review): the default ``None`` is *not* a "leave unchanged" marker:
        it sets socket_timeout and socket_connect_timeout to None.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the host and/or timeouts saved when notifications were configured."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
656 "Manages communication to and from a Redis server"
658 @deprecated_args(
659 args_to_warn=["lib_name", "lib_version"],
660 reason="Use 'driver_info' parameter instead. "
661 "lib_name and lib_version will be removed in a future version.",
662 )
663 def __init__(
664 self,
665 db: int = 0,
666 password: Optional[str] = None,
667 socket_timeout: Optional[float] = None,
668 socket_connect_timeout: Optional[float] = None,
669 retry_on_timeout: bool = False,
670 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
671 encoding: str = "utf-8",
672 encoding_errors: str = "strict",
673 decode_responses: bool = False,
674 parser_class=DefaultParser,
675 socket_read_size: int = 65536,
676 health_check_interval: int = 0,
677 client_name: Optional[str] = None,
678 lib_name: Optional[str] = None,
679 lib_version: Optional[str] = None,
680 driver_info: Optional[DriverInfo] = None,
681 username: Optional[str] = None,
682 retry: Union[Any, None] = None,
683 redis_connect_func: Optional[Callable[[], None]] = None,
684 credential_provider: Optional[CredentialProvider] = None,
685 protocol: Optional[int] = 2,
686 command_packer: Optional[Callable[[], None]] = None,
687 event_dispatcher: Optional[EventDispatcher] = None,
688 maint_notifications_config: Optional[MaintNotificationsConfig] = None,
689 maint_notifications_pool_handler: Optional[
690 MaintNotificationsPoolHandler
691 ] = None,
692 maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
693 maintenance_notification_hash: Optional[int] = None,
694 orig_host_address: Optional[str] = None,
695 orig_socket_timeout: Optional[float] = None,
696 orig_socket_connect_timeout: Optional[float] = None,
697 ):
698 """
699 Initialize a new Connection.
701 To specify a retry policy for specific errors, first set
702 `retry_on_error` to a list of the error/s to retry on, then set
703 `retry` to a valid `Retry` object.
704 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
706 Parameters
707 ----------
708 driver_info : DriverInfo, optional
709 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
710 are ignored. If not provided, a DriverInfo will be created from lib_name
711 and lib_version (or defaults if those are also None).
712 lib_name : str, optional
713 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
714 lib_version : str, optional
715 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
716 """
717 if (username or password) and credential_provider is not None:
718 raise DataError(
719 "'username' and 'password' cannot be passed along with 'credential_"
720 "provider'. Please provide only one of the following arguments: \n"
721 "1. 'password' and (optional) 'username'\n"
722 "2. 'credential_provider'"
723 )
724 if event_dispatcher is None:
725 self._event_dispatcher = EventDispatcher()
726 else:
727 self._event_dispatcher = event_dispatcher
728 self.pid = os.getpid()
729 self.db = db
730 self.client_name = client_name
732 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
733 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
735 self.credential_provider = credential_provider
736 self.password = password
737 self.username = username
738 self._socket_timeout = socket_timeout
739 if socket_connect_timeout is None:
740 socket_connect_timeout = socket_timeout
741 self._socket_connect_timeout = socket_connect_timeout
742 self.retry_on_timeout = retry_on_timeout
743 if retry_on_error is SENTINEL:
744 retry_on_errors_list = []
745 else:
746 retry_on_errors_list = list(retry_on_error)
747 if retry_on_timeout:
748 # Add TimeoutError to the errors list to retry on
749 retry_on_errors_list.append(TimeoutError)
750 self.retry_on_error = retry_on_errors_list
751 if retry or self.retry_on_error:
752 if retry is None:
753 self.retry = Retry(NoBackoff(), 1)
754 else:
755 # deep-copy the Retry object as it is mutable
756 self.retry = copy.deepcopy(retry)
757 if self.retry_on_error:
758 # Update the retry's supported errors with the specified errors
759 self.retry.update_supported_errors(self.retry_on_error)
760 else:
761 self.retry = Retry(NoBackoff(), 0)
762 self.health_check_interval = health_check_interval
763 self.next_health_check = 0
764 self.redis_connect_func = redis_connect_func
765 self.encoder = Encoder(encoding, encoding_errors, decode_responses)
766 self.handshake_metadata = None
767 self._sock = None
768 self._socket_read_size = socket_read_size
769 self._connect_callbacks = []
770 self._buffer_cutoff = 6000
771 self._re_auth_token: Optional[TokenInterface] = None
772 try:
773 p = int(protocol)
774 except TypeError:
775 p = DEFAULT_RESP_VERSION
776 except ValueError:
777 raise ConnectionError("protocol must be an integer")
778 finally:
779 if p < 2 or p > 3:
780 raise ConnectionError("protocol must be either 2 or 3")
781 # p = DEFAULT_RESP_VERSION
782 self.protocol = p
783 if self.protocol == 3 and parser_class == _RESP2Parser:
784 # If the protocol is 3 but the parser is RESP2, change it to RESP3
785 # This is needed because the parser might be set before the protocol
786 # or might be provided as a kwarg to the constructor
787 # We need to react on discrepancy only for RESP2 and RESP3
788 # as hiredis supports both
789 parser_class = _RESP3Parser
790 self.set_parser(parser_class)
792 self._command_packer = self._construct_command_packer(command_packer)
793 self._should_reconnect = False
795 # Set up maintenance notifications
796 MaintNotificationsAbstractConnection.__init__(
797 self,
798 maint_notifications_config,
799 maint_notifications_pool_handler,
800 maintenance_state,
801 maintenance_notification_hash,
802 orig_host_address,
803 orig_socket_timeout,
804 orig_socket_connect_timeout,
805 self._parser,
806 )
808 def __repr__(self):
809 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
810 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
812 @abstractmethod
813 def repr_pieces(self):
814 pass
816 def __del__(self):
817 try:
818 self.disconnect()
819 except Exception:
820 pass
822 def _construct_command_packer(self, packer):
823 if packer is not None:
824 return packer
825 elif HIREDIS_AVAILABLE:
826 return HiredisRespSerializer()
827 else:
828 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
830 def register_connect_callback(self, callback):
831 """
832 Register a callback to be called when the connection is established either
833 initially or reconnected. This allows listeners to issue commands that
834 are ephemeral to the connection, for example pub/sub subscription or
835 key tracking. The callback must be a _method_ and will be kept as
836 a weak reference.
837 """
838 wm = weakref.WeakMethod(callback)
839 if wm not in self._connect_callbacks:
840 self._connect_callbacks.append(wm)
842 def deregister_connect_callback(self, callback):
843 """
844 De-register a previously registered callback. It will no-longer receive
845 notifications on connection events. Calling this is not required when the
846 listener goes away, since the callbacks are kept as weak methods.
847 """
848 try:
849 self._connect_callbacks.remove(weakref.WeakMethod(callback))
850 except ValueError:
851 pass
853 def set_parser(self, parser_class):
854 """
855 Creates a new instance of parser_class with socket size:
856 _socket_read_size and assigns it to the parser for the connection
857 :param parser_class: The required parser class
858 """
859 self._parser = parser_class(socket_read_size=self._socket_read_size)
861 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
862 return self._parser
864 def connect(self):
865 "Connects to the Redis server if not already connected"
866 # try once the socket connect with the handshake, retry the whole
867 # connect/handshake flow based on retry policy
868 self.retry.call_with_retry(
869 lambda: self.connect_check_health(
870 check_health=True, retry_socket_connect=False
871 ),
872 lambda error: self.disconnect(error),
873 )
875 def connect_check_health(
876 self, check_health: bool = True, retry_socket_connect: bool = True
877 ):
878 if self._sock:
879 return
880 try:
881 if retry_socket_connect:
882 sock = self.retry.call_with_retry(
883 lambda: self._connect(), lambda error: self.disconnect(error)
884 )
885 else:
886 sock = self._connect()
887 except socket.timeout:
888 raise TimeoutError("Timeout connecting to server")
889 except OSError as e:
890 raise ConnectionError(self._error_message(e))
892 self._sock = sock
893 try:
894 if self.redis_connect_func is None:
895 # Use the default on_connect function
896 self.on_connect_check_health(check_health=check_health)
897 else:
898 # Use the passed function redis_connect_func
899 self.redis_connect_func(self)
900 except RedisError:
901 # clean up after any error in on_connect
902 self.disconnect()
903 raise
905 # run any user callbacks. right now the only internal callback
906 # is for pubsub channel/pattern resubscription
907 # first, remove any dead weakrefs
908 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
909 for ref in self._connect_callbacks:
910 callback = ref()
911 if callback:
912 callback(self)
914 @abstractmethod
915 def _connect(self):
916 pass
918 @abstractmethod
919 def _host_error(self):
920 pass
922 def _error_message(self, exception):
923 return format_error_message(self._host_error(), exception)
925 def on_connect(self):
926 self.on_connect_check_health(check_health=True)
928 def on_connect_check_health(self, check_health: bool = True):
929 "Initialize the connection, authenticate and select a database"
930 self._parser.on_connect(self)
931 parser = self._parser
933 auth_args = None
934 # if credential provider or username and/or password are set, authenticate
935 if self.credential_provider or (self.username or self.password):
936 cred_provider = (
937 self.credential_provider
938 or UsernamePasswordCredentialProvider(self.username, self.password)
939 )
940 auth_args = cred_provider.get_credentials()
942 # if resp version is specified and we have auth args,
943 # we need to send them via HELLO
944 if auth_args and self.protocol not in [2, "2"]:
945 if isinstance(self._parser, _RESP2Parser):
946 self.set_parser(_RESP3Parser)
947 # update cluster exception classes
948 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
949 self._parser.on_connect(self)
950 if len(auth_args) == 1:
951 auth_args = ["default", auth_args[0]]
952 # avoid checking health here -- PING will fail if we try
953 # to check the health prior to the AUTH
954 self.send_command(
955 "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
956 )
957 self.handshake_metadata = self.read_response()
958 # if response.get(b"proto") != self.protocol and response.get(
959 # "proto"
960 # ) != self.protocol:
961 # raise ConnectionError("Invalid RESP version")
962 elif auth_args:
963 # avoid checking health here -- PING will fail if we try
964 # to check the health prior to the AUTH
965 self.send_command("AUTH", *auth_args, check_health=False)
967 try:
968 auth_response = self.read_response()
969 except AuthenticationWrongNumberOfArgsError:
970 # a username and password were specified but the Redis
971 # server seems to be < 6.0.0 which expects a single password
972 # arg. retry auth with just the password.
973 # https://github.com/andymccurdy/redis-py/issues/1274
974 self.send_command("AUTH", auth_args[-1], check_health=False)
975 auth_response = self.read_response()
977 if str_if_bytes(auth_response) != "OK":
978 raise AuthenticationError("Invalid Username or Password")
980 # if resp version is specified, switch to it
981 elif self.protocol not in [2, "2"]:
982 if isinstance(self._parser, _RESP2Parser):
983 self.set_parser(_RESP3Parser)
984 # update cluster exception classes
985 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
986 self._parser.on_connect(self)
987 self.send_command("HELLO", self.protocol, check_health=check_health)
988 self.handshake_metadata = self.read_response()
989 if (
990 self.handshake_metadata.get(b"proto") != self.protocol
991 and self.handshake_metadata.get("proto") != self.protocol
992 ):
993 raise ConnectionError("Invalid RESP version")
995 # Activate maintenance notifications for this connection
996 # if enabled in the configuration
997 # This is a no-op if maintenance notifications are not enabled
998 self.activate_maint_notifications_handling_if_enabled(check_health=check_health)
1000 # if a client_name is given, set it
1001 if self.client_name:
1002 self.send_command(
1003 "CLIENT",
1004 "SETNAME",
1005 self.client_name,
1006 check_health=check_health,
1007 )
1008 if str_if_bytes(self.read_response()) != "OK":
1009 raise ConnectionError("Error setting client name")
1011 # Set the library name and version from driver_info
1012 try:
1013 if self.driver_info and self.driver_info.formatted_name:
1014 self.send_command(
1015 "CLIENT",
1016 "SETINFO",
1017 "LIB-NAME",
1018 self.driver_info.formatted_name,
1019 check_health=check_health,
1020 )
1021 self.read_response()
1022 except ResponseError:
1023 pass
1025 try:
1026 if self.driver_info and self.driver_info.lib_version:
1027 self.send_command(
1028 "CLIENT",
1029 "SETINFO",
1030 "LIB-VER",
1031 self.driver_info.lib_version,
1032 check_health=check_health,
1033 )
1034 self.read_response()
1035 except ResponseError:
1036 pass
1038 # if a database is specified, switch to it
1039 if self.db:
1040 self.send_command("SELECT", self.db, check_health=check_health)
1041 if str_if_bytes(self.read_response()) != "OK":
1042 raise ConnectionError("Invalid Database")
1044 def disconnect(self, *args):
1045 "Disconnects from the Redis server"
1046 self._parser.on_disconnect()
1048 conn_sock = self._sock
1049 self._sock = None
1050 # reset the reconnect flag
1051 self.reset_should_reconnect()
1052 if conn_sock is None:
1053 return
1055 if os.getpid() == self.pid:
1056 try:
1057 conn_sock.shutdown(socket.SHUT_RDWR)
1058 except (OSError, TypeError):
1059 pass
1061 try:
1062 conn_sock.close()
1063 except OSError:
1064 pass
1066 def mark_for_reconnect(self):
1067 self._should_reconnect = True
1069 def should_reconnect(self):
1070 return self._should_reconnect
1072 def reset_should_reconnect(self):
1073 self._should_reconnect = False
1075 def _send_ping(self):
1076 """Send PING, expect PONG in return"""
1077 self.send_command("PING", check_health=False)
1078 if str_if_bytes(self.read_response()) != "PONG":
1079 raise ConnectionError("Bad response from PING health check")
1081 def _ping_failed(self, error):
1082 """Function to call when PING fails"""
1083 self.disconnect()
1085 def check_health(self):
1086 """Check the health of the connection with a PING/PONG"""
1087 if self.health_check_interval and time.monotonic() > self.next_health_check:
1088 self.retry.call_with_retry(self._send_ping, self._ping_failed)
1090 def send_packed_command(self, command, check_health=True):
1091 """Send an already packed command to the Redis server"""
1092 if not self._sock:
1093 self.connect_check_health(check_health=False)
1094 # guard against health check recursion
1095 if check_health:
1096 self.check_health()
1097 try:
1098 if isinstance(command, str):
1099 command = [command]
1100 for item in command:
1101 self._sock.sendall(item)
1102 except socket.timeout:
1103 self.disconnect()
1104 raise TimeoutError("Timeout writing to socket")
1105 except OSError as e:
1106 self.disconnect()
1107 if len(e.args) == 1:
1108 errno, errmsg = "UNKNOWN", e.args[0]
1109 else:
1110 errno = e.args[0]
1111 errmsg = e.args[1]
1112 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1113 except BaseException:
1114 # BaseExceptions can be raised when a socket send operation is not
1115 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1116 # to send un-sent data. However, the send_packed_command() API
1117 # does not support it so there is no point in keeping the connection open.
1118 self.disconnect()
1119 raise
1121 def send_command(self, *args, **kwargs):
1122 """Pack and send a command to the Redis server"""
1123 self.send_packed_command(
1124 self._command_packer.pack(*args),
1125 check_health=kwargs.get("check_health", True),
1126 )
1128 def can_read(self, timeout=0):
1129 """Poll the socket to see if there's data that can be read."""
1130 sock = self._sock
1131 if not sock:
1132 self.connect()
1134 host_error = self._host_error()
1136 try:
1137 return self._parser.can_read(timeout)
1139 except OSError as e:
1140 self.disconnect()
1141 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
1143 def read_response(
1144 self,
1145 disable_decoding=False,
1146 *,
1147 disconnect_on_error=True,
1148 push_request=False,
1149 ):
1150 """Read the response from a previously sent command"""
1152 host_error = self._host_error()
1154 try:
1155 if self.protocol in ["3", 3]:
1156 response = self._parser.read_response(
1157 disable_decoding=disable_decoding, push_request=push_request
1158 )
1159 else:
1160 response = self._parser.read_response(disable_decoding=disable_decoding)
1161 except socket.timeout:
1162 if disconnect_on_error:
1163 self.disconnect()
1164 raise TimeoutError(f"Timeout reading from {host_error}")
1165 except OSError as e:
1166 if disconnect_on_error:
1167 self.disconnect()
1168 raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
1169 except BaseException:
1170 # Also by default close in case of BaseException. A lot of code
1171 # relies on this behaviour when doing Command/Response pairs.
1172 # See #1128.
1173 if disconnect_on_error:
1174 self.disconnect()
1175 raise
1177 if self.health_check_interval:
1178 self.next_health_check = time.monotonic() + self.health_check_interval
1180 if isinstance(response, ResponseError):
1181 try:
1182 raise response
1183 finally:
1184 del response # avoid creating ref cycles
1185 return response
1187 def pack_command(self, *args):
1188 """Pack a series of arguments into the Redis protocol"""
1189 return self._command_packer.pack(*args)
1191 def pack_commands(self, commands):
1192 """Pack multiple commands into the Redis protocol"""
1193 output = []
1194 pieces = []
1195 buffer_length = 0
1196 buffer_cutoff = self._buffer_cutoff
1198 for cmd in commands:
1199 for chunk in self._command_packer.pack(*cmd):
1200 chunklen = len(chunk)
1201 if (
1202 buffer_length > buffer_cutoff
1203 or chunklen > buffer_cutoff
1204 or isinstance(chunk, memoryview)
1205 ):
1206 if pieces:
1207 output.append(SYM_EMPTY.join(pieces))
1208 buffer_length = 0
1209 pieces = []
1211 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1212 output.append(chunk)
1213 else:
1214 pieces.append(chunk)
1215 buffer_length += chunklen
1217 if pieces:
1218 output.append(SYM_EMPTY.join(pieces))
1219 return output
1221 def get_protocol(self) -> Union[int, str]:
1222 return self.protocol
1224 @property
1225 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1226 return self._handshake_metadata
1228 @handshake_metadata.setter
1229 def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
1230 self._handshake_metadata = value
1232 def set_re_auth_token(self, token: TokenInterface):
1233 self._re_auth_token = token
1235 def re_auth(self):
1236 if self._re_auth_token is not None:
1237 self.send_command(
1238 "AUTH",
1239 self._re_auth_token.try_get("oid"),
1240 self._re_auth_token.get_value(),
1241 )
1242 self.read_response()
1243 self._re_auth_token = None
1245 def _get_socket(self) -> Optional[socket.socket]:
1246 return self._sock
1248 @property
1249 def socket_timeout(self) -> Optional[Union[float, int]]:
1250 return self._socket_timeout
1252 @socket_timeout.setter
1253 def socket_timeout(self, value: Optional[Union[float, int]]):
1254 self._socket_timeout = value
1256 @property
1257 def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1258 return self._socket_connect_timeout
1260 @socket_connect_timeout.setter
1261 def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1262 self._socket_connect_timeout = value
class Connection(AbstractConnection):
    """TCP transport to a Redis server (``host``:``port``)."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        # TCP-specific settings; everything else goes to the base class.
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = (
            socket_keepalive_options if socket_keepalive_options else {}
        )
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs identifying this connection."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces += [("client_name", self.client_name)]
        return pieces

    def _connect(self):
        """Open a TCP socket to the server.

        Like ``socket.create_connection``, this tries each address returned by
        ``getaddrinfo`` (IPv4 and IPv6) in turn. Unlike it, TCP_NODELAY and the
        optional keepalive settings are applied before ``connect()`` is called.
        The connect timeout applies while connecting; ``socket_timeout`` is
        installed once connected.

        Raises:
            OSError: the error from the last failed address, or a generic one
                if ``getaddrinfo`` returned nothing.
        """
        last_error = None
        candidates = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )
        for family, socktype, proto, _canonname, address in candidates:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, value in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, option, value)

                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)
                sock.settimeout(self.socket_timeout)
                return sock
            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Describe the endpoint as ``host:port`` for error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        """Hostname or address this connection targets."""
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value
class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection wrapper adding client-side caching on top of ``conn``.

    Commands that ``cache.is_cachable`` accepts are answered from ``cache``
    when a completed entry exists. Otherwise they are sent over the wrapped
    connection and the reply is stored. A connect callback enables
    ``CLIENT TRACKING ON`` so the server pushes invalidation messages, which
    evict the affected entries (or flush the cache). Everything else is
    delegated to the wrapped connection.
    """

    # Placeholder value stored while a cacheable command is in flight.
    DUMMY_CACHE_VALUE = b"foo"
    # Oldest server version accepted by connect().
    MIN_ALLOWED_VERSION = "7.4.0"
    # Server name (from the HELLO reply) accepted by connect().
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        # `host` is a property that writes through to the wrapped connection.
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting its reply (None when the
        # pending command is not cacheable).
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching.

        Raises:
            ConnectionError: if the handshake metadata lacks the server name or
                version, or the server is not ``DEFAULT_SERVER_NAME`` at
                ``MIN_ALLOWED_VERSION`` or later.
        """
        self._conn.connect()

        # The HELLO reply may have bytes or str keys depending on decoding.
        # Treat missing metadata as "no information" so the error below fires
        # instead of an AttributeError.
        metadata = self._conn.handshake_metadata or {}
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        # NOTE(review): relies on compare_versions() returning 1 when the first
        # version is older than the second — confirm in redis.utils.
        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the local cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Forward check_health so callers that disable the health check
        # (e.g. to avoid recursion) are honoured by the wrapped connection.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or skip sending when its reply is already cached.

        Non-cacheable commands are sent straight through. Cacheable commands
        require a ``keys`` keyword to build the cache key. If an entry for that
        key already exists the command is not sent, and ``read_response`` serves
        it. Otherwise an IN_PROGRESS placeholder is stored and the command is sent.

        Raises:
            ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            entry = self._cache.get(self._current_command_cache_key)
            if entry is not None:
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply for the last command sent through this proxy.

        If the pending command has a completed (not IN_PROGRESS) cache entry, a
        deep copy of the cached value is returned without reading the socket.
        Otherwise the reply is read from the wrapped connection. For cacheable
        commands it is stored in the cache; ``None`` replies are removed
        instead. If the read raises, the pending placeholder is removed before
        the exception propagates.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            key = self._current_command_cache_key
            if key is not None:
                entry = self._cache.get(key)
                if entry is not None and entry.status != CacheEntryStatus.IN_PROGRESS:
                    res = copy.deepcopy(entry.cache_value)
                    self._current_command_cache_key = None
                    return res

        try:
            response = self._conn.read_response(
                disable_decoding=disable_decoding,
                disconnect_on_error=disconnect_on_error,
                push_request=push_request,
            )
        except BaseException:
            # Without this cleanup a failed read (error reply, timeout, ...)
            # would leave the IN_PROGRESS placeholder behind. send_command()
            # would then keep finding it and skip sending the command, and the
            # next read would wait for a reply that never comes.
            with self._cache_lock:
                key = self._current_command_cache_key
                if key is not None:
                    self._cache.delete_by_cache_keys([key])
                    self._current_command_cache_key = None
            raise

        with self._cache_lock:
            key = self._current_command_cache_key
            # Prevent not-allowed command from caching.
            if key is None:
                return response
            self._current_command_cache_key = None

            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([key])
                return response

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            cache_entry = self._cache.get(key)
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        # Propagate the wrapped connection's socket (previously discarded).
        return self._conn._connect()

    def _host_error(self):
        # Propagate the endpoint description (previously discarded, so
        # callers received None).
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable server-assisted tracking on ``conn``."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain any push messages already waiting on the wrapped connection."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys; ``data[1] is None`` flushes the whole cache."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Open a TCP socket with Connection._connect and wrap it with TLS.

        The plain socket is closed if wrapping fails for any reason, not only
        OSError/RedisError (e.g. ValueError from bad CA data), so it is never
        leaked.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Builds an SSLContext from the connection's settings, wraps ``sock``,
        and optionally performs OCSP (stapled or full) validation. If
        validation fails, the TLS socket is closed before the error propagates.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if OCSP validation is requested but unavailable, or
                both OCSP modes are requested.
            ConnectionError: if full OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            # validation for the stapled case
            if self.ssl_validate_ocsp_stapled:
                import OpenSSL

                from .ocsp import ocsp_staple_verifier

                # if a context is provided use it - otherwise, a basic context
                if self.ssl_ocsp_context is None:
                    staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                    staple_ctx.use_certificate_file(self.certfile)
                    staple_ctx.use_privatekey_file(self.keyfile)
                else:
                    staple_ctx = self.ssl_ocsp_context

                staple_ctx.set_ocsp_client_callback(
                    ocsp_staple_verifier, self.ssl_ocsp_expected_cert
                )

                # The staple check needs another socket; close it once the
                # handshake completes or fails instead of leaking it.
                with socket.socket() as staple_sock:
                    con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
                    con.request_ocsp()
                    con.connect((self.host, self.port))
                    con.do_handshake()
                    con.shutdown()

            # pure ocsp validation
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            # After wrap_socket() the TLS socket owns the connection (CPython
            # detaches the original `sock`), so closing `sock` in _connect()
            # would not release it. Close the TLS socket here.
            sslsock.close()
            raise

        return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    """Connection to a Redis server through the Unix domain socket at ``path``."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs identifying this connection."""
        base = [("path", self.path), ("db", self.db)]
        if not self.client_name:
            return base
        return base + [("client_name", self.client_name)]

    def _connect(self):
        """Open and connect an AF_UNIX stream socket to ``self.path``.

        The connect timeout applies while connecting; ``socket_timeout`` is
        installed afterwards. On failure the socket is closed before the
        ``OSError`` is re-raised, so no ResourceWarning is emitted.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()
            raise
        else:
            sock.settimeout(self.socket_timeout)
            return sock

    def _host_error(self):
        """Describe the endpoint (the socket path) for error messages."""
        return self.path
# Upper-cased spellings that to_bool() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Convert a URL query-string value into a boolean.

    ``None`` and ``""`` map to ``None``. Strings whose upper-cased form is in
    FALSE_STRINGS map to ``False``. Everything else goes through ``bool()``.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)
def parse_ssl_verify_flags(value):
    """Parse a textual list of ``ssl.VerifyFlags`` member names.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``. Brackets are
    optional and empty entries (``"[]"``, trailing commas) are ignored.

    Returns:
        A list of ``VerifyFlags`` members, in the order given.

    Raises:
        ValueError: if a name is not a ``VerifyFlags`` member, or Python was
            built without SSL support.
    """
    if VerifyFlags is None:
        raise ValueError("ssl verify flags require Python built with SSL support")

    verify_flags_str = value.replace("[", "").replace("]", "")

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        # Look up real enum members only. hasattr() would also accept
        # arbitrary class attributes such as "__doc__" or "mro", which are
        # not flags and would break `context.verify_flags |= flag` later.
        member = VerifyFlags.__members__.get(flag)
        if member is None:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(member)
    return verify_flags
# Converters applied to known URL query-string parameters by parse_url();
# parameters not listed here are passed through as plain strings.
# NOTE(review): "retry_on_error" uses list(), which turns a string value into
# a list of single characters — confirm this is intended.
URL_QUERY_ARGUMENT_PARSERS = dict(
    db=int,
    socket_timeout=float,
    socket_connect_timeout=float,
    socket_keepalive=to_bool,
    retry_on_timeout=to_bool,
    retry_on_error=list,
    max_connections=int,
    health_check_interval=int,
    ssl_check_hostname=to_bool,
    ssl_include_verify_flags=parse_ssl_verify_flags,
    ssl_exclude_verify_flags=parse_ssl_verify_flags,
    timeout=float,
)
def parse_url(url):
    """Turn a Redis connection URL into connection keyword arguments.

    Accepted schemes are ``redis://``, ``rediss://`` (adds
    ``connection_class=SSLConnection``) and ``unix://`` (adds ``path`` and
    ``connection_class=UnixDomainSocketConnection``). Query parameters are
    converted with URL_QUERY_ARGUMENT_PARSERS when a converter exists.
    Credentials are percent-decoded. For TCP URLs, a numeric path such as
    ``/2`` selects ``db`` unless a ``db`` query parameter was given.

    Raises:
        ValueError: for an unsupported scheme or a query value its converter
            rejects.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        converter = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not converter:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = converter(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # Remaining schemes: redis:// and rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A path like "/2" selects the database unless the query string did.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs
# Type variable for ConnectionPool-typed values; the bound is a string forward
# reference to "ConnectionPool". NOTE(review): presumably used to type
# alternate constructors that return the calling pool subclass — confirm.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPoolInterface(ABC):
    """Abstract set of operations every connection pool implementation provides."""

    @abstractmethod
    def get_protocol(self):
        """Protocol accessor; implemented by concrete pools."""

    @abstractmethod
    def reset(self):
        """Pool reset hook; implemented by concrete pools."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Hand out a connection.

        Passing any arguments triggers a deprecation warning (since 5.3.0);
        callers should use ``get_connection()`` with no arguments.
        """

    @abstractmethod
    def get_encoder(self):
        """Encoder accessor; implemented by concrete pools."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Give ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections; ``inuse_connections`` defaults to True."""

    @abstractmethod
    def close(self):
        """Close the pool; implemented by concrete pools."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Install ``retry`` as the retry policy."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authentication hook receiving a new ``token``."""
class MaintNotificationsAbstractConnectionPool:
    """
    Mixin carrying the maintenance-notification logic shared by the
    ConnectionPool classes.

    It is not meant to be used on its own. Concrete pools must provide
    ``connection_kwargs`` (getter and setter), ``_get_pool_lock``,
    ``_get_free_connections`` and ``_get_in_use_connections``; everything
    else about maintenance notifications and pool handling lives here.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        Set up the pool-level maintenance notifications handler.

        :param maint_notifications_config: explicit configuration. When it is
            omitted and ``protocol`` is RESP3, a default
            ``MaintNotificationsConfig()`` is used.
        :param kwargs: the pool's connection kwargs; only ``protocol`` is read.
        :raises RedisError: if notifications are enabled without RESP3.
        """
        resp3 = kwargs.get("protocol") in [3, "3"]
        config = maint_notifications_config
        if config is None and resp3:
            config = MaintNotificationsConfig()

        # Disabled (or absent) configuration: no handler at all.
        if not (config and config.enabled):
            self._maint_notifications_pool_handler = None
            return

        if not resp3:
            raise RedisError(
                "Maintenance notifications handlers on connection are only supported with RESP version 3"
            )

        handler = MaintNotificationsPoolHandler(self, config)
        self._maint_notifications_pool_handler = handler
        self._update_connection_kwargs_for_maint_notifications(handler)

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to build new connections."""

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        """Replace the keyword arguments used to build new connections."""

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock guarding the pool's connection containers."""

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently idle in the pool."""

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently checked out of the pool."""

    def maint_notifications_enabled(self):
        """
        Report whether maintenance notifications are enabled for this pool.

        The configuration is stored on the pool handler; without a handler
        the notifications are considered disabled.

        Returns:
            A falsy value when there is no handler or config, otherwise the
            config's ``enabled`` attribute.
        """
        handler = self._maint_notifications_pool_handler
        config = handler.config if handler else None
        return config and config.enabled

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Replace the maintenance notifications configuration.

        Intended for pools created without maintenance notifications that
        need them enabled later on. Connection kwargs for future connections
        are updated, and existing connections are re-configured (idle ones
        disconnected, in-use ones marked for reconnect).

        :raises ValueError: if notifications are currently enabled and the
            new configuration would disable them.
        """
        if self.maint_notifications_enabled():
            if not maint_notifications_config.enabled:
                raise ValueError(
                    "Cannot disable maintenance notifications after enabling them"
                )

        # Pool-level settings first...
        handler = self._maint_notifications_pool_handler
        if handler:
            handler.config = maint_notifications_config
        else:
            handler = MaintNotificationsPoolHandler(self, maint_notifications_config)
            self._maint_notifications_pool_handler = handler

        # ...then future connections' kwargs and the existing connections.
        self._update_connection_kwargs_for_maint_notifications(handler)
        self._update_maint_notifications_configs_for_connections(handler)

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Inject the handler and its config into the kwargs used for every
        future connection (no-op while notifications are disabled).
        """
        if not self.maint_notifications_enabled():
            return

        kwargs = self.connection_kwargs
        kwargs.update(
            maint_notifications_pool_handler=maint_notifications_pool_handler,
            maint_notifications_config=maint_notifications_pool_handler.config,
        )

        # Snapshot the original connection parameters only once: a missing
        # or None ``orig_host_address`` means they were never recorded.
        if kwargs.get("orig_host_address", None) is None:
            kwargs.update(
                orig_host_address=kwargs.get("host"),
                orig_socket_timeout=kwargs.get("socket_timeout", None),
                orig_socket_connect_timeout=kwargs.get(
                    "socket_connect_timeout", None
                ),
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Push the handler/config to every pooled connection.

        Idle connections are disconnected; in-use ones are marked for
        reconnect.
        """

        def attach(conn):
            conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )
            conn.maint_notifications_config = maint_notifications_pool_handler.config

        with self._get_pool_lock():
            for conn in self._get_free_connections():
                attach(conn)
                conn.disconnect()
            for conn in self._get_in_use_connections():
                attach(conn)
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Decide whether ``conn`` matches the given criteria.

        A falsy ``matching_address`` / ``matching_notification_hash`` matches
        every connection; unknown patterns also match.
        """
        if matching_pattern == "connected_address":
            return not (matching_address and conn.getpeername() != matching_address)
        if matching_pattern == "configured_address":
            return not (matching_address and conn.host != matching_address)
        if matching_pattern == "notification_hash":
            return not (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            )
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Apply the requested maintenance settings to a single connection.

        The socket timeout is always refreshed with ``relaxed_timeout``
        (which may be ``None``) as the final step.
        """
        if state:
            conn.maintenance_state = state

        # The hash is only touched when explicitly requested.
        if update_notification_hash:
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_host_address or reset_relaxed_timeout:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Apply settings to every matching connection in the pool.

        Neither new connections nor the connection kwargs are affected.
        In-use connections are processed first, then (optionally) free ones.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance
            notification to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash.
        :param reset_host_address: Whether to reset the host address to the
            original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to
            the original timeout.
        :param include_free_connections: Whether to include free/available
            connections.
        """

        def apply(conn):
            if not self._should_update_connection(
                conn,
                matching_pattern,
                matching_address,
                matching_notification_hash,
            ):
                return
            self.update_connection_settings(
                conn,
                state=state,
                maintenance_notification_hash=maintenance_notification_hash,
                host_address=host_address,
                relaxed_timeout=relaxed_timeout,
                update_notification_hash=update_notification_hash,
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                apply(conn)
            if include_free_connections:
                for conn in self._get_free_connections():
                    apply(conn)

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Merge ``kwargs`` into the kwargs used for future connections.

        Connections that already exist are left untouched.
        """
        future_kwargs = self.connection_kwargs
        future_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Flag in-use connections for reconnect, e.g. when a cluster node is
        migrated to a different address.

        :param moving_address_src: The address of the node that is being moved;
            only connections whose peer address matches are flagged.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if not self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    continue
                conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect idle connections, e.g. when a cluster node is migrated to
        a different address.

        :param moving_address_src: The address of the node that is being moved;
            only connections whose peer address matches are disconnected.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if not self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    continue
                conn.disconnect()
class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates an SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicitly passed connection_class overrides the one implied by
        # the URL scheme.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        # URL-derived options take precedence over keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        # ``None`` (or 0) means "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

        # Strip the cache options so they are not forwarded to
        # ``connection_class`` by make_connection().
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments passed to ``connection_class`` for new connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Forget all connections and reset the creation counter."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # Pending data is tolerated when a client cache or maintenance
            # notifications are enabled.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new connection.

        :raises MaxConnectionsError: if ``max_connections`` connections have
            already been created.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")

        kwargs = dict(self.connection_kwargs)
        connection = self.connection_class(**kwargs)
        if self.cache is not None:
            connection = CacheProxyConnection(connection, self.cache, self._lock)

        # Count the connection only once it was actually built, so a failing
        # constructor does not permanently consume one of the pool's slots.
        self._created_connections += 1
        return connection

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created in this pool's current process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """
        Use ``retry`` for all future connections and for every connection
        currently owned by the pool.
        """
        self.connection_kwargs.update({"retry": retry})
        # Iterate under the pool lock like every other method touching these
        # containers: get_connection()/release() mutate them concurrently and
        # mutating a set while iterating it raises RuntimeError.
        with self._lock:
            for conn in self._available_connections:
                conn.retry = retry
            for conn in self._in_use_connections:
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate with ``token``: idle connections send ``AUTH``
        immediately (with retries); in-use connections get the token stored
        via ``set_re_auth_token``.
        """
        with self._lock:
            for conn in self._available_connections:
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        # Returns the live list (not a copy).
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        # Returns the live set (not a copy).
        with self._lock:
            return self._in_use_connections

    def _mock(self, error: RedisError) -> None:
        """
        No-op failure callback passed to the retry object.

        This must be a regular (synchronous) function: the callback is
        invoked synchronously, and an ``async def`` here would only create a
        coroutine that is never awaited (triggering a RuntimeWarning).

        :param error: the error reported by the retry object (ignored).
        """
        pass
2772class BlockingConnectionPool(ConnectionPool):
2773 """
2774 Thread-safe blocking connection pool::
2776 >>> from redis.client import Redis
2777 >>> client = Redis(connection_pool=BlockingConnectionPool())
2779 It performs the same function as the default
2780 :py:class:`~redis.ConnectionPool` implementation, in that,
2781 it maintains a pool of reusable connections that can be shared by
2782 multiple redis clients (safely across threads if required).
2784 The difference is that, in the event that a client tries to get a
2785 connection from the pool when all of connections are in use, rather than
2786 raising a :py:class:`~redis.ConnectionError` (as the default
2787 :py:class:`~redis.ConnectionPool` implementation does), it
2788 makes the client wait ("blocks") for a specified number of seconds until
2789 a connection becomes available.
2791 Use ``max_connections`` to increase / decrease the pool size::
2793 >>> pool = BlockingConnectionPool(max_connections=10)
2795 Use ``timeout`` to tell it either how many seconds to wait for a connection
2796 to become available, or to block forever:
2798 >>> # Block forever.
2799 >>> pool = BlockingConnectionPool(timeout=None)
2801 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
2802 >>> # not available.
2803 >>> pool = BlockingConnectionPool(timeout=5)
2804 """
2806 def __init__(
2807 self,
2808 max_connections=50,
2809 timeout=20,
2810 connection_class=Connection,
2811 queue_class=LifoQueue,
2812 **connection_kwargs,
2813 ):
2814 self.queue_class = queue_class
2815 self.timeout = timeout
2816 self._in_maintenance = False
2817 self._locked = False
2818 super().__init__(
2819 connection_class=connection_class,
2820 max_connections=max_connections,
2821 **connection_kwargs,
2822 )
2824 def reset(self):
2825 # Create and fill up a thread safe queue with ``None`` values.
2826 try:
2827 if self._in_maintenance:
2828 self._lock.acquire()
2829 self._locked = True
2830 self.pool = self.queue_class(self.max_connections)
2831 while True:
2832 try:
2833 self.pool.put_nowait(None)
2834 except Full:
2835 break
2837 # Keep a list of actual connection instances so that we can
2838 # disconnect them later.
2839 self._connections = []
2840 finally:
2841 if self._locked:
2842 try:
2843 self._lock.release()
2844 except Exception:
2845 pass
2846 self._locked = False
2848 # this must be the last operation in this method. while reset() is
2849 # called when holding _fork_lock, other threads in this process
2850 # can call _checkpid() which compares self.pid and os.getpid() without
2851 # holding any lock (for performance reasons). keeping this assignment
2852 # as the last operation ensures that those other threads will also
2853 # notice a pid difference and block waiting for the first thread to
2854 # release _fork_lock. when each of these threads eventually acquire
2855 # _fork_lock, they will notice that another thread already called
2856 # reset() and they will immediately release _fork_lock and continue on.
2857 self.pid = os.getpid()
2859 def make_connection(self):
2860 "Make a fresh connection."
2861 try:
2862 if self._in_maintenance:
2863 self._lock.acquire()
2864 self._locked = True
2866 if self.cache is not None:
2867 connection = CacheProxyConnection(
2868 self.connection_class(**self.connection_kwargs),
2869 self.cache,
2870 self._lock,
2871 )
2872 else:
2873 connection = self.connection_class(**self.connection_kwargs)
2874 self._connections.append(connection)
2875 return connection
2876 finally:
2877 if self._locked:
2878 try:
2879 self._lock.release()
2880 except Exception:
2881 pass
2882 self._locked = False
2884 @deprecated_args(
2885 args_to_warn=["*"],
2886 reason="Use get_connection() without args instead",
2887 version="5.3.0",
2888 )
def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection, blocking for ``self.timeout`` until a connection
    is available from the pool.

    If the connection returned is ``None`` then creates a new connection.
    Because we use a last-in first-out queue, the existing connections
    (having been returned to the pool after the initial ``None`` values
    were added) will be returned before ``None`` values. This means we only
    create new connections when we need to, i.e.: the actual number of
    connections will only increase in response to demand.

    ``command_name``, ``keys`` and ``options`` are not used by this method.

    Raises:
        ConnectionError: if no connection becomes available within
            ``self.timeout``, or if the connection still reports unread
            data after being reconnected (in which case it is handed back
            via ``self.release`` before the error propagates).
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # In maintenance mode, pool access is serialized through ``self._lock``.
    # Whether *this call* holds the lock is tracked in a local variable rather
    # than in shared instance state (``self._locked``). With a shared flag, a
    # concurrent call on another thread could see ``True`` and release (or try
    # to release) a lock it never acquired. It would then clear the flag, so
    # the real owner never releases the lock. Capturing the lock object once
    # also guarantees we release the very object we acquired.
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        # Try and get a connection from the pool. If one isn't available
        # within self.timeout then raise a ``ConnectionError``.
        # NOTE(review): in maintenance mode the lock is held while blocking
        # here, and ``release()`` also takes the lock in maintenance mode, so
        # no connection can be returned to the queue during this wait.
        # Confirm that waiting out the full timeout in that case is intended.
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. With
            # ``self.timeout`` set to None, ``pool.get`` blocks until a
            # connection is available instead of raising.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()
    finally:
        if lock is not None:
            lock.release()

    try:
        # ensure this connection is connected to Redis
        connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if connection.can_read():
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # release the connection back to the pool so that we don't leak it
        self.release(connection)
        raise

    return connection
def release(self, connection):
    """
    Release ``connection`` back to the pool.

    A connection this pool does not own (per ``self.owns_connection``) is
    disconnected and replaced by a ``None`` placeholder, which causes the
    pool to create a fresh connection when one is next needed. An owned
    connection is disconnected first if ``should_reconnect()`` says so, then
    put back on the queue. If the queue is already full, the connection or
    placeholder is simply dropped.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # Track lock ownership locally (not via shared ``self._locked``) so a
    # concurrent call on another thread can neither release a lock this call
    # holds nor make this call release one it never acquired.
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            try:
                self.pool.put_nowait(None)
            except Full:
                # The queue is already full (e.g. the pool was reset()), so
                # no placeholder is needed; don't let Full escape release().
                pass
            return
        if connection.should_reconnect():
            connection.disconnect()
        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass
    finally:
        if lock is not None:
            lock.release()
def disconnect(self, inuse_connections: bool = True) -> None:
    """
    Disconnect either all connections in the pool or just the free ones.

    Args:
        inuse_connections: when True (the default), disconnect every
            connection tracked in ``self._connections``; when False, only
            the idle connections returned by ``self._get_free_connections()``.
    """
    self._checkpid()

    # Track lock ownership locally (not via shared ``self._locked``) so a
    # concurrent call on another thread can neither release a lock this call
    # holds nor make this call release one it never acquired.
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        if inuse_connections:
            connections = self._connections
        else:
            connections = self._get_free_connections()
        for connection in connections:
            connection.disconnect()
    finally:
        if lock is not None:
            lock.release()
def _get_free_connections(self):
    """
    Return the set of idle connections currently sitting in the queue.

    ``None`` placeholders are skipped. The snapshot is taken while holding
    ``self._lock``.
    """
    with self._lock:
        # filter(None, ...) keeps only truthy entries, dropping placeholders.
        idle = set(filter(None, self.pool.queue))
    return idle
def _get_in_use_connections(self):
    """
    Return the set of connections that are currently checked out.

    Every created connection is tracked in ``self._connections``. Whatever
    is not idle in the queue (ignoring ``None`` placeholders) is in use.
    Computed while holding ``self._lock``.
    """
    with self._lock:
        idle = set(filter(None, self.pool.queue))
        in_use = set(self._connections) - idle
    return in_use
def set_in_maintenance(self, in_maintenance: bool):
    """
    Mark this BlockingConnectionPool as in (or out of) maintenance mode.

    The flag is stored on ``self._in_maintenance``. Pool operations that
    check it serialize themselves through the pool lock while it is set.
    Per the original design, the pool is expected to be in maintenance
    mode only while a MOVING notification is being processed.
    """
    self._in_maintenance = in_maintenance