1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
24
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
33
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .event import AfterConnectionReleasedEvent, EventDispatcher
39from .exceptions import (
40 AuthenticationError,
41 AuthenticationWrongNumberOfArgsError,
42 ChildDeadlockedError,
43 ConnectionError,
44 DataError,
45 MaxConnectionsError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from .maintenance_events import (
51 MaintenanceEventConnectionHandler,
52 MaintenanceEventPoolHandler,
53 MaintenanceEventsConfig,
54 MaintenanceState,
55)
56from .retry import Retry
57from .utils import (
58 CRYPTOGRAPHY_AVAILABLE,
59 HIREDIS_AVAILABLE,
60 SSL_AVAILABLE,
61 compare_versions,
62 deprecated_args,
63 ensure_string,
64 format_error_message,
65 get_lib_version,
66 str_if_bytes,
67)
68
69if SSL_AVAILABLE:
70 import ssl
71else:
72 ssl = None
73
74if HIREDIS_AVAILABLE:
75 import hiredis
76
# RESP framing tokens shared by the command packers below.
SYM_STAR = b"*"  # array header prefix
SYM_DOLLAR = b"$"  # bulk-string length prefix
SYM_CRLF = b"\r\n"  # line terminator
SYM_EMPTY = b""  # joiner used to concatenate byte pieces

# RESP version used when ``protocol`` cannot be converted with int() (TypeError).
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied", distinct from ``None``.
SENTINEL = object()

# Use the hiredis-backed parser whenever the extension is importable.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
91
92
class HiredisRespSerializer:
    """Command serializer that delegates RESP encoding to ``hiredis.pack_command``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        command, rest = args[0], args[1:]
        # A command name such as "CONFIG GET" embeds literal sub-arguments; the
        # server expects each one as its own element, so split the head here.
        if isinstance(command, str):
            args = (*command.encode().split(), *rest)
        elif b" " in command:
            args = (*command.split(), *rest)

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Surface unsupported argument types as the library's DataError,
            # keeping the original traceback for debugging.
            raise DataError(exc).with_traceback(exc.__traceback__)

        return [packed]
109
110
class PythonRespSerializer:
    """Pure-Python RESP command serializer.

    Small arguments are coalesced into one bytes buffer. Any argument longer
    than ``buffer_cutoff``, or any ``memoryview``, is emitted as its own chunk
    so that large payloads are not copied into a single big string.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        # Size threshold (bytes) beyond which data is emitted as a separate chunk.
        self._buffer_cutoff = buffer_cutoff
        # Callable that turns a single argument into bytes (or a memoryview).
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        head = args[0]
        # The client might pass literal arguments inside the command name, e.g.
        # 'CONFIG GET'. The server wants them sent separately, so split the head
        # here. The pieces are bytes, so encode() passes them through unchanged.
        if isinstance(head, str):
            args = tuple(head.encode().split()) + args[1:]
        elif b" " in head:
            args = tuple(head.split()) + args[1:]

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            spill = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if not spill:
                # Small value: append header, payload and terminator inline.
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))
                continue
            # Large value: flush what we have plus the header, then emit the
            # payload as-is to avoid a large allocation.
            chunks.append(pending + header)
            chunks.append(encoded)
            pending = SYM_CRLF

        chunks.append(pending)
        return chunks
160
161
class ConnectionInterface:
    """Interface describing the operations a Redis connection object exposes.

    Every method here is decorated with ``abstractmethod``; this class does not
    use ``ABCMeta``, so the decorators serve as documentation of the contract.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used to build the object's repr."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run whenever the connection is (re)established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously passed to register_connect_callback."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a new parser instance built from ``parser_class``."""

    @abstractmethod
    def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler):
        """Attach a pool-level maintenance event handler to this connection."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection to the server."""

    @abstractmethod
    def check_health(self):
        """Verify that the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Write an already-packed command to the server."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and write a command to the server."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether data is available to read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the reply to a previously sent command."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize a single command into protocol chunks."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands (e.g. a pipeline) into protocol chunks."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used by the next re_auth() call."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate using the stored token, if any."""

    @property
    @abstractmethod
    def maintenance_state(self) -> MaintenanceState:
        """
        Returns the current maintenance state of the connection.
        """

    @maintenance_state.setter
    @abstractmethod
    def maintenance_state(self, state: "MaintenanceState"):
        """
        Sets the current maintenance state of the connection.
        """

    @abstractmethod
    def getpeername(self):
        """
        Returns the peer name of the connection.
        """

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def get_resolved_ip(self):
        """
        Get resolved ip address for the connection.
        """

    @abstractmethod
    def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
        """
        Update the timeout for the current socket.
        """

    @abstractmethod
    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relax_timeout: Optional[float] = None,
    ):
        """
        Updates temporary host address and timeout settings for the connection.
        """

    @abstractmethod
    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relax_timeout: bool = False,
    ):
        """
        Resets temporary host address and timeout settings for the connection.
        """
319
320
class AbstractConnection(ConnectionInterface):
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None,
        maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_event_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Notes:
            - ``retry_on_error=None`` is treated the same as not passing it.
            - ``socket_connect_timeout`` defaults to ``socket_timeout``.
            - ``protocol`` must be 2 or 3. A value that ``int()`` rejects with
              ``TypeError`` (e.g. ``None``) falls back to ``DEFAULT_RESP_VERSION``.

        Raises:
            DataError: ``username``/``password`` combined with
                ``credential_provider``.
            ConnectionError: ``protocol`` is not an integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        # None is accepted as "no extra errors" instead of failing in list(None).
        if retry_on_error is SENTINEL or retry_on_error is None:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # The range check must not live in a ``finally`` clause: after a
        # ValueError ``p`` is unbound there, and the resulting UnboundLocalError
        # would replace the ConnectionError raised above.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        if self.protocol == 3 and parser_class == DefaultParser:
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self.maintenance_events_config = maintenance_events_config

        # Set up maintenance events if enabled
        self._configure_maintenance_events(
            maintenance_events_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
        )

        self._should_reconnect = False
        self.maintenance_state = maintenance_state
        self.maintenance_event_hash = maintenance_event_hash

        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs shown in ``repr()``; provided by subclasses."""
        pass

    def __del__(self):
        # Best-effort cleanup: __init__ may have failed before the parser or
        # socket attributes existed, so any error here is ignored.
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Return ``packer`` if given, else hiredis-backed or pure-Python packer."""
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def _configure_maintenance_events(
        self,
        maintenance_events_pool_handler=None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
    ):
        """Enable maintenance events by setting up handlers and storing original connection parameters.

        The original host/timeouts are stored even when maintenance events are
        disabled. A pool handler can be attached later via
        set_maintenance_event_pool_handler(), and reset_tmp_settings() then
        needs these attributes to exist.
        """
        # Store original connection parameters. getattr() keeps this working for
        # subclasses without a ``host`` attribute. Timeouts use an explicit
        # ``is not None`` test so that a 0 timeout is not mistaken for "unset".
        self.orig_host_address = (
            orig_host_address if orig_host_address else getattr(self, "host", None)
        )
        self.orig_socket_timeout = (
            orig_socket_timeout
            if orig_socket_timeout is not None
            else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout is not None
            else self.socket_connect_timeout
        )

        if (
            not self.maintenance_events_config
            or not self.maintenance_events_config.enabled
        ):
            self._maintenance_event_connection_handler = None
            return

        # Set up pool handler if available
        if maintenance_events_pool_handler:
            self._parser.set_node_moving_push_handler(
                maintenance_events_pool_handler.handle_event
            )

        # Set up connection handler
        self._maintenance_event_connection_handler = MaintenanceEventConnectionHandler(
            self, self.maintenance_events_config
        )
        self._parser.set_maintenance_push_handler(
            self._maintenance_event_connection_handler.handle_event
        )

    def set_maintenance_event_pool_handler(
        self, maintenance_event_pool_handler: MaintenanceEventPoolHandler
    ):
        """Attach a pool-level maintenance handler and wire its push handlers.

        Creates the per-connection handler if none exists yet, otherwise updates
        its config to the pool handler's config.
        """
        maintenance_event_pool_handler.set_connection(self)
        self._parser.set_node_moving_push_handler(
            maintenance_event_pool_handler.handle_event
        )

        # Update maintenance event connection handler if it doesn't exist
        if not self._maintenance_event_connection_handler:
            self._maintenance_event_connection_handler = (
                MaintenanceEventConnectionHandler(
                    self, maintenance_event_pool_handler.config
                )
            )
            self._parser.set_maintenance_push_handler(
                self._maintenance_event_connection_handler.handle_event
            )
        else:
            self._maintenance_event_connection_handler.config = (
                maintenance_event_pool_handler.config
            )

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (optionally with retries), run the handshake, then
        invoke registered connect callbacks. No-op when already connected.

        Raises:
            TimeoutError: the socket connect timed out.
            ConnectionError: any other OSError while connecting.
        """
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return a connected socket; provided by subclasses."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a string identifying the server for error messages."""
        pass

    def _error_message(self, exception):
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            # NOTE(review): unlike the HELLO-without-AUTH branch below, the
            # negotiated "proto" is not verified here.
            self.handshake_metadata = self.read_response()
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The key type depends on decoding, so accept either form.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
        # and we have a host to determine the endpoint type from
        if (
            self.protocol not in [2, "2"]
            and self.maintenance_events_config
            and self.maintenance_events_config.enabled
            and self._maintenance_event_connection_handler
            and hasattr(self, "host")
        ):
            try:
                endpoint_type = self.maintenance_events_config.get_endpoint_type(
                    self.host, self
                )
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if str_if_bytes(response) != "OK":
                    raise ConnectionError(
                        "The server doesn't support maintenance notifications"
                    )
            except Exception as e:
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.warning("Failed to enable maintenance notifications: %s", e)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # CLIENT SETINFO errors (ResponseError) are ignored.
        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self._should_reconnect = False
        if conn_sock is None:
            return

        # Only shut the socket down in the process that created the connection;
        # after a fork the child just closes its copy of the descriptor.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server.

        Connects first if needed. On any send failure the connection is
        disconnected before the error propagates.

        Raises:
            TimeoutError: the write timed out.
            ConnectionError: an OSError occurred while writing.
        """
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # OSError may carry (errno, strerror), a single message, or no args.
            if len(e.args) >= 2:
                errno, errmsg = e.args[0], e.args[1]
            else:
                errno = "UNKNOWN"
                errmsg = e.args[0] if e.args else str(e)
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)

        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        A ResponseError returned by the parser is raised rather than returned.
        With ``disconnect_on_error`` (default) the connection is closed on any
        read failure.
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced until the buffer exceeds ``_buffer_cutoff``;
        chunks larger than the cutoff, and memoryviews, are emitted on their own.
        """
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def get_protocol(self) -> Union[int, str]:
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        self._re_auth_token = token

    def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            if self._sock is not None:
                peer_addr = self._sock.getpeername()
                # IP families return a (host, port, ...) tuple. Other families
                # (e.g. AF_UNIX) return a path string, whose first character
                # is not an IP address, so those fall through to Method 2.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return addr_info[0][4][0]
        except (AttributeError, OSError):
            # DNS resolution might fail (socket.gaierror is an OSError subclass)
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """Return the peer address of the socket, or None when not connected.

        For IP sockets this is the host element of the (host, port, ...) tuple.
        For other families (e.g. AF_UNIX) the peer name is returned unchanged
        rather than being indexed character-wise.
        """
        if not self._sock:
            return None
        peer = self._sock.getpeername()
        if isinstance(peer, tuple):
            return peer[0]
        return peer

    def mark_for_reconnect(self):
        self._should_reconnect = True

    def should_reconnect(self):
        return self._should_reconnect

    def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
        """Apply ``relax_timeout`` to the live socket and the parser buffer.

        ``-1`` restores ``self.socket_timeout``. Any other value (including the
        default ``None``) is applied as-is. Does nothing when not connected.
        """
        if self._sock:
            timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout
            self._sock.settimeout(timeout)
            self.update_parser_buffer_timeout(timeout)

    def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's read buffer, if it has one.

        Only buffers exposing a ``socket_timeout`` attribute are updated. Some
        parsers keep a plain bytearray in ``_buffer``, and setting an attribute
        on that would raise AttributeError.
        """
        buffer = getattr(self._parser, "_buffer", None) if self._parser else None
        if buffer and hasattr(buffer, "socket_timeout"):
            buffer.socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relax_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        For ``tmp_relax_timeout`` the "don't update" marker is ``-1``. Note that
        the default ``None`` is applied, which clears both socket timeouts.
        """
        if tmp_host_address is not SENTINEL:
            self.host = tmp_host_address
        if tmp_relax_timeout != -1:
            self.socket_timeout = tmp_relax_timeout
            self.socket_connect_timeout = tmp_relax_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relax_timeout: bool = False,
    ):
        """Restore the host and/or timeouts recorded at construction time."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relax_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
1059
1060
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """Create a TCP connection description (no socket is opened here).

        :param host: server hostname or address.
        :param port: server port; converted with ``int()``.
        :param socket_keepalive: enable SO_KEEPALIVE on the socket.
        :param socket_keepalive_options: mapping of IPPROTO_TCP option -> value
            applied when keepalive is enabled.
        :param socket_type: passed as the ``family`` argument of
            ``socket.getaddrinfo`` (0 lets it return any family).
        Remaining keyword arguments go to AbstractConnection.
        """
        # host/port must exist before super().__init__(), which may read
        # self.host while configuring maintenance events.
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None

        for family, socktype, proto, _canonname, socket_address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the failure and try the next resolved address.
                err = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()
            except BaseException:
                # Any other error (e.g. TypeError from a bad keepalive option
                # value, or KeyboardInterrupt) aborts the attempt. Close the
                # socket first so it is not leaked.
                if sock is not None:
                    sock.close()
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"
1134
1135
class CacheProxyConnection(ConnectionInterface):
    """Wrap a connection and serve cacheable read commands from a client-side cache.

    Commands that ``cache.is_cachable`` accepts are answered from ``cache``
    when a completed entry exists for their key. Otherwise they are sent over
    the wrapped connection and the response is stored. On connect,
    ``CLIENT TRACKING ON`` is issued, and invalidation pushes received on the
    wrapped connection evict entries from the cache.
    """

    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        Args:
            conn: the underlying connection that all I/O is delegated to.
            cache: the cache that stores responses.
            pool_lock: the owning pool's lock. It is held while draining
                pending invalidations from another connection.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting read_response(), or
        # None when the last command is not cacheable.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def connect(self):
        """Connect the wrapped connection and verify that the server supports caching.

        Raises:
            ConnectionError: if the handshake metadata lacks the server name
                or version, or if the server is not ``redis`` at version
                ``MIN_ALLOWED_VERSION`` or later.
        """
        self._conn.connect()

        metadata = self._conn.handshake_metadata
        # Handshake metadata keys may be bytes or str depending on decoding.
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the whole cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send an already-packed command; bypasses the cache entirely."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or arrange for it to be answered from the cache.

        Non-cacheable commands are always sent. For cacheable commands, a
        cache key is built from ``args[0]`` and ``kwargs["keys"]``. If a
        completed entry already exists, the command is not sent and
        ``read_response()`` returns the cached value. Otherwise an
        IN_PROGRESS placeholder is stored and the command is sent.

        Raises:
            ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        cache_key = CacheKey(command=args[0], redis_keys=tuple(kwargs.get("keys")))
        self._current_command_cache_key = cache_key

        with self._cache_lock:
            entry = self._cache.get(cache_key)
            if entry is not None:
                if entry.connection_ref != self._conn:
                    # The entry was populated by another connection. Process
                    # the invalidations queued on that connection so a stale
                    # value is not served and invalidations don't pile up
                    # on it.
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)
                    # Processing those invalidations may have evicted the entry.
                    entry = self._cache.get(cache_key)

                # Skip sending only when a completed value is available.
                # Otherwise read_response() would wait for the reply to a
                # command that was never sent.
                if (
                    entry is not None
                    and entry.status != CacheEntryStatus.IN_PROGRESS
                ):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply to the last command, from the cache when possible.

        A completed cache entry for the current command is returned as a deep
        copy, so the caller's object is independent of the cached one.
        Otherwise the reply is read from the wrapped connection. For cacheable
        commands it is then stored in the cache, unless it is ``None`` or the
        placeholder entry has meanwhile been removed.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            cache_key = self._current_command_cache_key
            if cache_key is not None:
                cached = self._cache.get(cache_key)
                if (
                    cached is not None
                    and cached.status != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(cached.cache_value)
                    self._current_command_cache_key = None
                    return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def _connect(self):
        return self._conn._connect()

    def _host_error(self):
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable server-assisted tracking and hook invalidations."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Read and handle every push message already waiting on the socket."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict entries for the invalidated keys in ``data[1]``.

        A ``None`` key list flushes the whole cache (the server flushed the DB).
        """
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

    def get_protocol(self):
        return self._conn.get_protocol()

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()
1348
1349
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". None is treated as "none".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. Forced to False when ssl_cert_reqs resolves to ssl.CERT_NONE.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if OCSP validation is requested without the
                cryptography package, or both OCSP modes are requested.
            ConnectionError: if pure OCSP validation fails.

        If OCSP validation fails for any reason, the SSL socket is closed
        before the error propagates.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                import OpenSSL

                from .ocsp import ocsp_staple_verifier

                # if a context is provided use it - otherwise, a basic context
                if self.ssl_ocsp_context is None:
                    staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                    staple_ctx.use_certificate_file(self.certfile)
                    staple_ctx.use_privatekey_file(self.keyfile)
                else:
                    staple_ctx = self.ssl_ocsp_context

                staple_ctx.set_ocsp_client_callback(
                    ocsp_staple_verifier, self.ssl_ocsp_expected_cert
                )

                # need another socket; always release it once the check is done
                staple_sock = socket.socket()
                try:
                    con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
                    con.request_ocsp()
                    con.connect((self.host, self.port))
                    con.do_handshake()
                    con.shutdown()
                finally:
                    staple_sock.close()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            # Don't leak the wrapped socket. wrap_socket() takes over the
            # plain socket's descriptor, so closing ``sock`` in _connect()
            # may not release it.
            sslsock.close()
            raise
        return sslsock
1519
1520
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer, so these values win over
        # anything it may have set.
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the (name, value) pairs describing this connection."""
        details = [("path", self.path), ("db", self.db)]
        if self.client_name:
            details += [("client_name", self.client_name)]
        return details

    def _connect(self):
        """Open and return a socket connected to the Unix socket at ``self.path``."""
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Close explicitly so a failed connect does not leave an unclosed
            # socket behind (ResourceWarning).
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        else:
            uds.settimeout(self.socket_timeout)
            return uds

    def _host_error(self):
        return self.path
1554
1555
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Convert a querystring-style value to a bool.

    Returns None for ``None`` or ``""``. Returns False for strings whose
    upper-cased form is in ``FALSE_STRINGS``. Any other value gives
    ``bool(value)``.
    """
    if value is None or value == "":
        return None
    spelled_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if spelled_false else bool(value)
1565
1566
# Converters that parse_url() applies to querystring values, keyed by option
# name. A converter raising TypeError/ValueError makes parse_url() raise
# ValueError. Options not listed here are passed through as unquoted strings.
# NOTE(review): ``list`` applied to a querystring value (a str) splits it into
# single characters, so ``retry_on_error`` cannot be meaningfully set from a
# URL. Confirm whether this is intended.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "timeout": float,
}
1579
1580
def parse_url(url):
    """Turn a ``redis://``, ``rediss://`` or ``unix://`` URL into connection kwargs.

    Querystring values are unquoted and passed through the converters in
    ``URL_QUERY_ARGUMENT_PARSERS`` (unknown options are kept as strings).
    Username and password are unquoted. For ``unix://`` the path and
    ``UnixDomainSocketConnection`` are set. For TCP schemes, host and port are
    set, and the URL path gives ``db`` unless a ``db`` query option was given.
    ``rediss://`` selects ``SSLConnection``.

    Raises:
        ValueError: for an unsupported scheme or an unconvertible query value.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    options = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        converter = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not converter:
            options[name] = raw
            continue
        try:
            options[name] = converter(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        options["username"] = unquote(parsed.username)
    if parsed.password:
        options["password"] = unquote(parsed.password)

    # Only redis://, rediss:// and unix:// get this far.
    if parsed.scheme == "unix":
        if parsed.path:
            options["path"] = unquote(parsed.path)
        options["connection_class"] = UnixDomainSocketConnection
        return options

    # redis:// or rediss://
    if parsed.hostname:
        options["host"] = unquote(parsed.hostname)
    if parsed.port:
        options["port"] = int(parsed.port)

    # A numeric path selects the db unless the querystring already did.
    if parsed.path and "db" not in options:
        try:
            options["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        options["connection_class"] = SSLConnection

    return options
1636
1637
# Type variable bound to ConnectionPool, so classmethod constructors such as
# from_url() are typed as returning the class they were called on.
_CP = TypeVar("_CP", bound="ConnectionPool")
1639
1640
1641class ConnectionPool:
1642 """
    Create a connection pool. If ``max_connections`` is set, then this
1644 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
1645 limit is reached.
1646
1647 By default, TCP connections are created unless ``connection_class``
    is specified. Use :class:`.UnixDomainSocketConnection` for
1649 unix sockets.
1650 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
1651
1652 Any additional keyword arguments are passed to the constructor of
1653 ``connection_class``.
1654 """
1655
1656 @classmethod
1657 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
1658 """
1659 Return a connection pool configured from the given URL.
1660
1661 For example::
1662
1663 redis://[[username]:[password]]@localhost:6379/0
1664 rediss://[[username]:[password]]@localhost:6379/0
1665 unix://[username@]/path/to/socket.sock?db=0[&password=password]
1666
1667 Three URL schemes are supported:
1668
1669 - `redis://` creates a TCP socket connection. See more at:
1670 <https://www.iana.org/assignments/uri-schemes/prov/redis>
1671 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
1672 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
1673 - ``unix://``: creates a Unix Domain Socket connection.
1674
1675 The username, password, hostname, path and all querystring values
1676 are passed through urllib.parse.unquote in order to replace any
1677 percent-encoded values with their corresponding characters.
1678
1679 There are several ways to specify a database number. The first value
1680 found will be used:
1681
1682 1. A ``db`` querystring option, e.g. redis://localhost?db=0
1683 2. If using the redis:// or rediss:// schemes, the path argument
1684 of the url, e.g. redis://localhost/0
1685 3. A ``db`` keyword argument to this function.
1686
1687 If none of these options are specified, the default db=0 is used.
1688
1689 All querystring options are cast to their appropriate Python types.
1690 Boolean arguments can be specified with string values "True"/"False"
1691 or "Yes"/"No". Values that cannot be properly cast cause a
1692 ``ValueError`` to be raised. Once parsed, the querystring arguments
1693 and keyword arguments are passed to the ``ConnectionPool``'s
1694 class initializer. In the case of conflicting arguments, querystring
1695 arguments always win.
1696 """
1697 url_options = parse_url(url)
1698
1699 if "connection_class" in kwargs:
1700 url_options["connection_class"] = kwargs["connection_class"]
1701
1702 kwargs.update(url_options)
1703 return cls(**kwargs)
1704
    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        """Initialize the pool.

        Args:
            connection_class: class that ``make_connection()`` instantiates
                with ``connection_kwargs``.
            max_connections: upper bound on the connections this pool creates.
                ``None`` or 0 means 2**31.
            cache_factory: builds the client-side cache when caching is
                requested (``cache_config``) but no ``cache`` instance is given.
            **connection_kwargs: forwarded to ``connection_class``. The
                ``cache`` and ``cache_config`` entries are consumed here.

        Raises:
            ValueError: if ``max_connections`` is not a non-negative int, or if
                ``cache`` does not implement ``CacheInterface``.
            RedisError: if caching or maintenance events are requested without
                ``protocol=3``.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        # Same dict object as ``connection_kwargs``: the pops/updates below are
        # visible through self.connection_kwargs as well.
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        # The cache lives on the pool; don't pass these on to connections.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        if connection_kwargs.get(
            "maintenance_events_pool_handler"
        ) or connection_kwargs.get("maintenance_events_config"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError(
                    "Push handlers on connection are only supported with RESP version 3"
                )
            # Prefer an explicit config; fall back to the handler's config.
            config = connection_kwargs.get("maintenance_events_config", None) or (
                connection_kwargs.get("maintenance_events_pool_handler").config
                if connection_kwargs.get("maintenance_events_pool_handler")
                else None
            )

            if config and config.enabled:
                # Remember the configured values so connections can restore
                # them after temporary overrides (see the orig_* attributes
                # read by reset_tmp_settings()).
                connection_kwargs.update(
                    {
                        "orig_host_address": connection_kwargs.get("host"),
                        "orig_socket_timeout": connection_kwargs.get(
                            "socket_timeout", None
                        ),
                        "orig_socket_connect_timeout": connection_kwargs.get(
                            "socket_connect_timeout", None
                        ),
                    }
                )

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()
1787
1788 def __repr__(self) -> str:
1789 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
1790 return (
1791 f"<{self.__class__.__module__}.{self.__class__.__name__}"
1792 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
1793 f"({conn_kwargs})>)>"
1794 )
1795
1796 def get_protocol(self):
1797 """
1798 Returns:
1799 The RESP protocol version, or ``None`` if the protocol is not specified,
1800 in which case the server default will be used.
1801 """
1802 return self.connection_kwargs.get("protocol", None)
1803
1804 def maintenance_events_pool_handler_enabled(self):
1805 """
1806 Returns:
1807 True if the maintenance events pool handler is enabled, False otherwise.
1808 """
1809 maintenance_events_config = self.connection_kwargs.get(
1810 "maintenance_events_config", None
1811 )
1812
1813 return maintenance_events_config and maintenance_events_config.enabled
1814
1815 def set_maintenance_events_pool_handler(
1816 self, maintenance_events_pool_handler: MaintenanceEventPoolHandler
1817 ):
1818 self.connection_kwargs.update(
1819 {
1820 "maintenance_events_pool_handler": maintenance_events_pool_handler,
1821 "maintenance_events_config": maintenance_events_pool_handler.config,
1822 }
1823 )
1824
1825 self._update_maintenance_events_configs_for_connections(
1826 maintenance_events_pool_handler
1827 )
1828
1829 def _update_maintenance_events_configs_for_connections(
1830 self, maintenance_events_pool_handler
1831 ):
1832 """Update the maintenance events config for all connections in the pool."""
1833 with self._lock:
1834 for conn in self._available_connections:
1835 conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
1836 conn.maintenance_events_config = maintenance_events_pool_handler.config
1837 for conn in self._in_use_connections:
1838 conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
1839 conn.maintenance_events_config = maintenance_events_pool_handler.config
1840
    def reset(self) -> None:
        """Forget all tracked connections and record the current process id.

        Connections are only dropped from the pool's bookkeeping; nothing is
        disconnected here.
        """
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
1856
    def _checkpid(self) -> None:
        """Reset the pool if the current process id differs from ``self.pid``.

        Raises:
            ChildDeadlockedError: if ``_fork_lock`` cannot be acquired within
                5 seconds after a pid change.
        """
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()
1903
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """Get a connection from the pool.

        Reuses the most recently released idle connection if there is one;
        otherwise creates one via ``make_connection()``. The connection is
        connected before it is returned. Unless client-side caching or
        maintenance events are enabled, it is also checked for unread data,
        and a connection with unread data is reconnected once. The
        (deprecated) arguments are not used by this method.

        Raises:
            MaxConnectionsError: from ``make_connection()`` when the pool limit
                is reached.
            ConnectionError: if the connection still has data to read after
                reconnecting, or if connecting fails.
        """

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # NOTE(review): presumably pending data is tolerated with caching
            # or maintenance events because push messages may be queued.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maintenance_events_pool_handler_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection
1945
1946 def get_encoder(self) -> Encoder:
1947 "Return an encoder based on encoding settings"
1948 kwargs = self.connection_kwargs
1949 return Encoder(
1950 encoding=kwargs.get("encoding", "utf-8"),
1951 encoding_errors=kwargs.get("encoding_errors", "strict"),
1952 decode_responses=kwargs.get("decode_responses", False),
1953 )
1954
1955 def make_connection(self) -> "ConnectionInterface":
1956 "Create a new connection"
1957 if self._created_connections >= self.max_connections:
1958 raise MaxConnectionsError("Too many connections")
1959 self._created_connections += 1
1960
1961 kwargs = dict(self.connection_kwargs)
1962
1963 if self.cache is not None:
1964 return CacheProxyConnection(
1965 self.connection_class(**kwargs), self.cache, self._lock
1966 )
1967 return self.connection_class(**kwargs)
1968
1969 def release(self, connection: "Connection") -> None:
1970 "Releases the connection back to the pool"
1971 self._checkpid()
1972 with self._lock:
1973 try:
1974 self._in_use_connections.remove(connection)
1975 except KeyError:
1976 # Gracefully fail when a connection is returned to this pool
1977 # that the pool doesn't actually own
1978 return
1979
1980 if self.owns_connection(connection):
1981 if connection.should_reconnect():
1982 connection.disconnect()
1983 self._available_connections.append(connection)
1984 self._event_dispatcher.dispatch(
1985 AfterConnectionReleasedEvent(connection)
1986 )
1987 else:
1988 # Pool doesn't own this connection, do not add it back
1989 # to the pool.
1990 # The created connections count should not be changed,
1991 # because the connection was not created by the pool.
1992 connection.disconnect()
1993 return
1994
1995 def owns_connection(self, connection: "Connection") -> int:
1996 return connection.pid == self.pid
1997
1998 def disconnect(self, inuse_connections: bool = True) -> None:
1999 """
2000 Disconnects connections in the pool
2001
2002 If ``inuse_connections`` is True, disconnect connections that are
2003 current in use, potentially by other threads. Otherwise only disconnect
2004 connections that are idle in the pool.
2005 """
2006 self._checkpid()
2007 with self._lock:
2008 if inuse_connections:
2009 connections = chain(
2010 self._available_connections, self._in_use_connections
2011 )
2012 else:
2013 connections = self._available_connections
2014
2015 for connection in connections:
2016 connection.disconnect()
2017
2018 def close(self) -> None:
2019 """Close the pool, disconnecting all connections"""
2020 self.disconnect()
2021
2022 def set_retry(self, retry: Retry) -> None:
2023 self.connection_kwargs.update({"retry": retry})
2024 for conn in self._available_connections:
2025 conn.retry = retry
2026 for conn in self._in_use_connections:
2027 conn.retry = retry
2028
2029 def re_auth_callback(self, token: TokenInterface):
2030 with self._lock:
2031 for conn in self._available_connections:
2032 conn.retry.call_with_retry(
2033 lambda: conn.send_command(
2034 "AUTH", token.try_get("oid"), token.get_value()
2035 ),
2036 lambda error: self._mock(error),
2037 )
2038 conn.retry.call_with_retry(
2039 lambda: conn.read_response(), lambda error: self._mock(error)
2040 )
2041 for conn in self._in_use_connections:
2042 conn.set_re_auth_token(token)
2043
2044 def _should_update_connection(
2045 self,
2046 conn: "Connection",
2047 matching_pattern: Literal[
2048 "connected_address", "configured_address", "event_hash"
2049 ] = "connected_address",
2050 matching_address: Optional[str] = None,
2051 matching_event_hash: Optional[int] = None,
2052 ) -> bool:
2053 """
2054 Check if the connection should be updated based on the matching address.
2055 """
2056 if matching_pattern == "connected_address":
2057 if matching_address and conn.getpeername() != matching_address:
2058 return False
2059 elif matching_pattern == "configured_address":
2060 if matching_address and conn.host != matching_address:
2061 return False
2062 elif matching_pattern == "event_hash":
2063 if (
2064 matching_event_hash
2065 and conn.maintenance_event_hash != matching_event_hash
2066 ):
2067 return False
2068 return True
2069
2070 def update_connection_settings(
2071 self,
2072 conn: "Connection",
2073 state: Optional["MaintenanceState"] = None,
2074 maintenance_event_hash: Optional[int] = None,
2075 host_address: Optional[str] = None,
2076 relax_timeout: Optional[float] = None,
2077 update_event_hash: bool = False,
2078 reset_host_address: bool = False,
2079 reset_relax_timeout: bool = False,
2080 ):
2081 """
2082 Update the settings for a single connection.
2083 """
2084 if state:
2085 conn.maintenance_state = state
2086
2087 if update_event_hash:
2088 # update the event hash only if requested
2089 conn.maintenance_event_hash = maintenance_event_hash
2090
2091 if host_address is not None:
2092 conn.set_tmp_settings(tmp_host_address=host_address)
2093
2094 if relax_timeout is not None:
2095 conn.set_tmp_settings(tmp_relax_timeout=relax_timeout)
2096
2097 if reset_relax_timeout or reset_host_address:
2098 conn.reset_tmp_settings(
2099 reset_host_address=reset_host_address,
2100 reset_relax_timeout=reset_relax_timeout,
2101 )
2102
2103 conn.update_current_socket_timeout(relax_timeout)
2104
2105 def update_connections_settings(
2106 self,
2107 state: Optional["MaintenanceState"] = None,
2108 maintenance_event_hash: Optional[int] = None,
2109 host_address: Optional[str] = None,
2110 relax_timeout: Optional[float] = None,
2111 matching_address: Optional[str] = None,
2112 matching_event_hash: Optional[int] = None,
2113 matching_pattern: Literal[
2114 "connected_address", "configured_address", "event_hash"
2115 ] = "connected_address",
2116 update_event_hash: bool = False,
2117 reset_host_address: bool = False,
2118 reset_relax_timeout: bool = False,
2119 include_free_connections: bool = True,
2120 ):
2121 """
2122 Update the settings for all matching connections in the pool.
2123
2124 This method does not create new connections.
2125 This method does not affect the connection kwargs.
2126
2127 :param state: The maintenance state to set for the connection.
2128 :param maintenance_event_hash: The hash of the maintenance event
2129 to set for the connection.
2130 :param host_address: The host address to set for the connection.
2131 :param relax_timeout: The relax timeout to set for the connection.
2132 :param matching_address: The address to match for the connection.
2133 :param matching_event_hash: The event hash to match for the connection.
2134 :param matching_pattern: The pattern to match for the connection.
2135 :param update_event_hash: Whether to update the event hash for the connection.
2136 :param reset_host_address: Whether to reset the host address to the original address.
2137 :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout.
2138 :param include_free_connections: Whether to include free/available connections.
2139 """
2140 with self._lock:
2141 for conn in self._in_use_connections:
2142 if self._should_update_connection(
2143 conn,
2144 matching_pattern,
2145 matching_address,
2146 matching_event_hash,
2147 ):
2148 self.update_connection_settings(
2149 conn,
2150 state=state,
2151 maintenance_event_hash=maintenance_event_hash,
2152 host_address=host_address,
2153 relax_timeout=relax_timeout,
2154 update_event_hash=update_event_hash,
2155 reset_host_address=reset_host_address,
2156 reset_relax_timeout=reset_relax_timeout,
2157 )
2158
2159 if include_free_connections:
2160 for conn in self._available_connections:
2161 if self._should_update_connection(
2162 conn,
2163 matching_pattern,
2164 matching_address,
2165 matching_event_hash,
2166 ):
2167 self.update_connection_settings(
2168 conn,
2169 state=state,
2170 maintenance_event_hash=maintenance_event_hash,
2171 host_address=host_address,
2172 relax_timeout=relax_timeout,
2173 update_event_hash=update_event_hash,
2174 reset_host_address=reset_host_address,
2175 reset_relax_timeout=reset_relax_timeout,
2176 )
2177
2178 def update_connection_kwargs(
2179 self,
2180 **kwargs,
2181 ):
2182 """
2183 Update the connection kwargs for all future connections.
2184
2185 This method updates the connection kwargs for all future connections created by the pool.
2186 Existing connections are not affected.
2187 """
2188 self.connection_kwargs.update(kwargs)
2189
2190 def update_active_connections_for_reconnect(
2191 self,
2192 moving_address_src: Optional[str] = None,
2193 ):
2194 """
2195 Mark all active connections for reconnect.
2196 This is used when a cluster node is migrated to a different address.
2197
2198 :param moving_address_src: The address of the node that is being moved.
2199 """
2200 with self._lock:
2201 for conn in self._in_use_connections:
2202 if self._should_update_connection(
2203 conn, "connected_address", moving_address_src
2204 ):
2205 conn.mark_for_reconnect()
2206
2207 def disconnect_free_connections(
2208 self,
2209 moving_address_src: Optional[str] = None,
2210 ):
2211 """
2212 Disconnect all free/available connections.
2213 This is used when a cluster node is migrated to a different address.
2214
2215 :param moving_address_src: The address of the node that is being moved.
2216 """
2217 with self._lock:
2218 for conn in self._available_connections:
2219 if self._should_update_connection(
2220 conn, "connected_address", moving_address_src
2221 ):
2222 conn.disconnect()
2223
2224 async def _mock(self, error: RedisError):
2225 """
2226 Dummy functions, needs to be passed as error callback to retry object.
2227 :param error:
2228 :return:
2229 """
2230 pass
2231
2232
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)

    While the pool is in maintenance mode (see :meth:`set_in_maintenance`),
    ``reset``, ``make_connection``, ``get_connection``, ``release`` and
    ``disconnect`` hold ``self._lock`` for their duration.
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        self.queue_class = queue_class
        self.timeout = timeout
        # Set before super().__init__ because reset() reads it (assumes the
        # base initializer calls reset(); TODO confirm).
        self._in_maintenance = False
        # Retained for backward compatibility only. Lock ownership is now
        # tracked per call (see _acquire_if_in_maintenance), because a shared
        # instance flag is clobbered by nested and concurrent callers.
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def _acquire_if_in_maintenance(self) -> bool:
        """
        Acquire ``self._lock`` if the pool is currently in maintenance mode.

        Returns True when the lock was acquired. The caller then owns that
        acquisition and must release it exactly once, typically in a
        ``finally`` block. Ownership is returned to the caller instead of
        being stored on ``self``, so concurrent calls, and nested calls such
        as ``get_connection`` -> ``make_connection``, can never release each
        other's acquisition. Nested acquisition assumes ``self._lock`` is
        reentrant (an RLock); TODO confirm.
        """
        if self._in_maintenance:
            self._lock.acquire()
            return True
        return False

    def _put_placeholder(self) -> None:
        """Put a ``None`` slot back into the queue, ignoring a full queue."""
        try:
            self.pool.put_nowait(None)
        except Full:
            # The queue already holds max_connections items (e.g. it was
            # refilled by reset() after a fork), so no slot is missing.
            pass

    def _idle_connections(self) -> "List[Connection]":
        """
        Return a snapshot of the real (non-``None``) connections idle in the queue.

        The underlying ``self.pool.queue`` is read under the queue's own
        ``mutex`` when it has one (``queue.Queue`` subclasses do). Outside
        maintenance mode, ``get``/``put`` on the queue do not take
        ``self._lock``, so iterating the live deque could observe concurrent
        mutation.
        """
        mutex = getattr(self.pool, "mutex", None)
        if mutex is None:
            return [conn for conn in tuple(self.pool.queue) if conn]
        with mutex:
            return [conn for conn in self.pool.queue if conn]

    def reset(self):
        """Refill the queue with ``None`` placeholders and forget all connections."""
        locked = self._acquire_if_in_maintenance()
        try:
            # Create and fill up a thread safe queue with ``None`` values.
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            if locked:
                self._lock.release()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        locked = self._acquire_if_in_maintenance()
        try:
            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            self._connections.append(connection)
            return connection
        finally:
            if locked:
                self._lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        Raises:
            ConnectionError: if no connection becomes available within
                ``self.timeout`` seconds. Errors raised while connecting are
                re-raised after the connection is released back to the pool.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # NOTE(review): in maintenance mode the lock is held across the
        # blocking ``pool.get`` below, while ``release`` also takes the lock
        # in maintenance mode. With an exhausted pool no connection can be
        # returned until this call times out (never, with timeout=None).
        locked = self._acquire_if_in_maintenance()
        try:
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Not caught by the redis client: this propagates to
                # application code. To wait indefinitely instead, create the
                # pool with ``timeout=None``.
                raise ConnectionError("No connection available.")

            # A ``None`` placeholder is the cue to create a new connection.
            if connection is None:
                try:
                    connection = self.make_connection()
                except BaseException:
                    # Return the placeholder, otherwise every failed creation
                    # would permanently shrink the pool.
                    self._put_placeholder()
                    raise
        finally:
            if locked:
                self._lock.release()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()

        locked = self._acquire_if_in_maintenance()
        try:
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed. The queue may already be full (e.g. right
                # after a fork reset), which must not raise here.
                connection.disconnect()
                self._put_placeholder()
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            if locked:
                self._lock.release()

    def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool.

        If ``inuse_connections`` is True (the default), disconnect every
        connection this pool created, including those currently checked out.
        Otherwise only disconnect connections that are idle in the pool.
        """
        self._checkpid()
        locked = self._acquire_if_in_maintenance()
        try:
            if inuse_connections:
                connections = tuple(self._connections)
            else:
                connections = self._idle_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            if locked:
                self._lock.release()

    def set_retry(self, retry: "Retry") -> None:
        """
        Use ``retry`` for future connections and for every existing one.

        Overrides the base implementation because this pool tracks its
        connections in ``self._connections``. This class's ``reset`` does not
        populate ``_available_connections`` / ``_in_use_connections``.
        """
        self.connection_kwargs.update({"retry": retry})
        with self._lock:
            for conn in tuple(self._connections):
                conn.retry = retry

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_event_hash: Optional[int] = None,
        relax_timeout: Optional[float] = None,
        host_address: Optional[str] = None,
        matching_address: Optional[str] = None,
        matching_event_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "event_hash"
        ] = "connected_address",
        update_event_hash: bool = False,
        reset_host_address: bool = False,
        reset_relax_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Override base class method to work with BlockingConnectionPool's structure.

        Every connection this pool created lives in ``self._connections``;
        idle ones are the non-``None`` entries of ``self.pool``. With
        ``include_free_connections=False`` only connections that are not
        idle (i.e. checked out) are updated.

        NOTE: ``relax_timeout`` and ``host_address`` appear in the opposite
        positional order from ``ConnectionPool.update_connections_settings``;
        pass them by keyword.
        """
        with self._lock:
            if include_free_connections:
                candidates = tuple(self._connections)
            else:
                idle = set(self._idle_connections())
                candidates = tuple(
                    conn for conn in self._connections if conn not in idle
                )

            for conn in candidates:
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_event_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_event_hash=maintenance_event_hash,
                        host_address=host_address,
                        relax_timeout=relax_timeout,
                        update_event_hash=update_event_hash,
                        reset_host_address=reset_host_address,
                        reset_relax_timeout=reset_relax_timeout,
                    )

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._lock:
            idle = set(self._idle_connections())
            for conn in tuple(self._connections):
                if conn in idle:
                    continue
                if self._should_update_connection(
                    conn,
                    matching_pattern="connected_address",
                    matching_address=moving_address_src,
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._lock:
            # Snapshot first: conn.disconnect() performs I/O, and the queue
            # may be mutated by other threads meanwhile.
            for conn in self._idle_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

    def _update_maintenance_events_config_for_connections(
        self, maintenance_events_config
    ):
        """Assign ``maintenance_events_config`` to every connection created so far."""
        for conn in tuple(self._connections):
            conn.maintenance_events_config = maintenance_events_config

    def _update_maintenance_events_configs_for_connections(
        self, maintenance_events_pool_handler
    ):
        """Update the maintenance events config for all connections in the pool."""
        with self._lock:
            for conn in tuple(self._connections):
                conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
                conn.maintenance_events_config = maintenance_events_pool_handler.config

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this BlockingConnectionPool is in maintenance mode.

        While the flag is set, ``reset``, ``make_connection``,
        ``get_connection``, ``release`` and ``disconnect`` acquire
        ``self._lock`` for their duration. They therefore wait for any other
        holder of that lock (e.g. ``update_connections_settings``) to finish.
        The pool will be in maintenance mode only when we are processing a MOVING event.
        """
        self._in_maintenance = in_maintenance