1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Iterable,
16 List,
17 Literal,
18 Optional,
19 Type,
20 TypeVar,
21 Union,
22)
23from urllib.parse import parse_qs, unquote, urlparse
24
25from redis.cache import (
26 CacheEntry,
27 CacheEntryStatus,
28 CacheFactory,
29 CacheFactoryInterface,
30 CacheInterface,
31 CacheKey,
32)
33
34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
35from .auth.token import TokenInterface
36from .backoff import NoBackoff
37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
38from .event import AfterConnectionReleasedEvent, EventDispatcher
39from .exceptions import (
40 AuthenticationError,
41 AuthenticationWrongNumberOfArgsError,
42 ChildDeadlockedError,
43 ConnectionError,
44 DataError,
45 MaxConnectionsError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from .maint_notifications import (
51 MaintenanceState,
52 MaintNotificationsConfig,
53 MaintNotificationsConnectionHandler,
54 MaintNotificationsPoolHandler,
55)
56from .retry import Retry
57from .utils import (
58 CRYPTOGRAPHY_AVAILABLE,
59 HIREDIS_AVAILABLE,
60 SSL_AVAILABLE,
61 compare_versions,
62 deprecated_args,
63 ensure_string,
64 format_error_message,
65 get_lib_version,
66 str_if_bytes,
67)
68
69if SSL_AVAILABLE:
70 import ssl
71 from ssl import VerifyFlags
72else:
73 ssl = None
74 VerifyFlags = None
75
76if HIREDIS_AVAILABLE:
77 import hiredis
78
# Building blocks of the RESP wire format used when serializing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when a connection is created with ``protocol=None``.
DEFAULT_RESP_VERSION = 2

# Unique "argument not supplied" marker, distinct from ``None``.
SENTINEL = object()

# Prefer the hiredis-backed parser whenever the extension is importable.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
93
94
class HiredisRespSerializer:
    """Serialize commands to the RESP wire format with ``hiredis.pack_command``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A first argument containing spaces (e.g. ``"CONFIG GET"``) is split
        into separate bytestring arguments before packing.

        Returns a single-element list holding the packed bytes, mirroring the
        list-of-chunks return shape of ``PythonRespSerializer.pack``.

        Raises:
            DataError: when hiredis rejects an argument with TypeError.
        """
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Re-raise as DataError while keeping the original traceback.
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
111
112
class PythonRespSerializer:
    """Pure-Python RESP command serializer.

    ``pack`` returns a list of byte chunks. Small arguments are coalesced into
    one running buffer. Arguments longer than ``buffer_cutoff`` and every
    memoryview are emitted as standalone chunks, so large values are never
    copied into one big bytes object.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        The first argument may embed literal sub-arguments (e.g.
        ``'CONFIG GET'``). The server expects them as separate arguments, so
        it is split on whitespace into bytestrings that are not re-encoded.
        """
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        chunks = []
        cutoff = self._buffer_cutoff
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            needs_own_chunk = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if needs_own_chunk:
                # Flush what we have plus this argument's header, then hand the
                # (possibly large) payload over without copying it.
                chunks.append(SYM_EMPTY.join((pending, header)))
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))

        chunks.append(pending)
        return chunks
162
163
class ConnectionInterface:
    """Contract shared by Redis connection implementations.

    NOTE(review): this class does not derive from ``abc.ABC`` (or use
    ``ABCMeta``), so the ``@abstractmethod`` markers are informational only.
    Instantiation is not blocked, and every method here returns ``None``.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs rendered by ``__repr__``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run whenever the connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously added via ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a fresh parser instance created from ``parser_class``."""

    @abstractmethod
    def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler):
        """Attach the pool-level maintenance notifications handler."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by the connection."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake for a freshly opened connection."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection to the server."""

    @abstractmethod
    def check_health(self):
        """Verify that the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Write an already-packed command to the connection."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack ``args`` into a single command and send it."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether response data can be read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the reply to a previously sent command."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize one command into protocol chunks."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands into protocol chunks."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for a later ``re_auth`` call."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token, if any."""

    @property
    @abstractmethod
    def maintenance_state(self) -> MaintenanceState:
        """
        Returns the current maintenance state of the connection.
        """

    @maintenance_state.setter
    @abstractmethod
    def maintenance_state(self, state: "MaintenanceState"):
        """
        Sets the current maintenance state of the connection.
        """

    @abstractmethod
    def getpeername(self):
        """
        Returns the peer name of the connection.
        """

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def get_resolved_ip(self):
        """
        Get resolved ip address for the connection.
        """

    @abstractmethod
    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Update the timeout for the current socket.
        """

    @abstractmethod
    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Updates temporary host address and timeout settings for the connection.
        """

    @abstractmethod
    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Resets temporary host address and timeout settings for the connection.
        """
321
322
class AbstractConnection(ConnectionInterface):
    """Manages communication to and from a Redis server.

    This base class is transport-independent. Subclasses provide:

    * ``_connect``, which creates and returns the socket used for I/O;
    * ``_host_error``, which describes the endpoint in error messages;
    * ``repr_pieces``.
    """

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        `protocol` is converted with `int()`. A value that raises TypeError
        (e.g. None) selects DEFAULT_RESP_VERSION; anything else must convert
        to 2 or 3.

        Raises:
            DataError: `username`/`password` combined with `credential_provider`.
            ConnectionError: `protocol` is not an integer or not 2/3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            # e.g. protocol=None -> fall back to the default RESP version
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # Validate after the try statement rather than in a ``finally``
        # clause. A ``finally`` would also run after the ValueError branch,
        # where ``p`` is unbound, and mask the ConnectionError above with
        # an UnboundLocalError.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        if self.protocol == 3 and parser_class == DefaultParser:
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self.maint_notifications_config = maint_notifications_config

        # Set up maintenance notifications if enabled
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
        )

        self._should_reconnect = False
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs rendered by ``__repr__``."""
        pass

    def __del__(self):
        """Best-effort disconnect on garbage collection; all errors are ignored."""
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Return ``packer`` if given, else a hiredis or pure-Python serializer."""
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler=None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
    ):
        """Enable maintenance notifications by setting up handlers and storing original connection parameters.

        The original host and timeouts are recorded even when notifications are
        disabled. ``set_maint_notifications_pool_handler`` can attach a handler
        later, and ``reset_tmp_settings`` must then be able to read them.
        """
        # Store original connection parameters. Subclasses without a ``host``
        # attribute (on_connect_check_health guards for that case as well)
        # record None instead of raising AttributeError.
        self.orig_host_address = (
            orig_host_address if orig_host_address else getattr(self, "host", None)
        )
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_connection_handler = None
            return

        # Set up pool handler if available
        if maint_notifications_pool_handler:
            self._parser.set_node_moving_push_handler(
                maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )
        self._parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

    def set_maint_notifications_pool_handler(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a pool-level maintenance notifications handler.

        Registers this connection with the handler and routes node-moving push
        notifications to it. Also ensures a connection-level handler exists
        that uses the pool handler's config.
        """
        maint_notifications_pool_handler.set_connection(self)
        self._parser.set_node_moving_push_handler(
            maint_notifications_pool_handler.handle_notification
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._parser.set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (if none), run the handshake, then fire connect callbacks.

        When ``retry_socket_connect`` is true, socket creation goes through
        ``self.retry``, which disconnects between attempts. A ``socket.timeout``
        becomes TimeoutError and other OSErrors become ConnectionError. A
        RedisError raised during the handshake disconnects and is re-raised.
        """
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return the socket used for this connection."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a description of the endpoint for use in error messages."""
        pass

    def _error_message(self, exception):
        """Format ``exception`` together with the endpoint description."""
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        """Run the connection handshake, including health checks."""
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # NOTE(review): unlike the HELLO-without-AUTH branch below, the
            # negotiated "proto" is not validated here.
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.protocol not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            try:
                endpoint_type = self.maint_notifications_config.get_endpoint_type(
                    self.host, self
                )
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
            except Exception as e:
                if (
                    isinstance(e, ResponseError)
                    and self.maint_notifications_config.enabled == "auto"
                ):
                    # Log warning but don't fail the connection
                    import logging

                    logger = logging.getLogger(__name__)
                    logger.warning("Failed to enable maintenance notifications: %s", e)
                else:
                    raise

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self._should_reconnect = False
        if conn_sock is None:
            return

        # Only shut the socket down from the process that created it.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG.

        PING is sent only when ``health_check_interval`` is set and
        ``next_health_check`` has passed. ``read_response`` pushes
        ``next_health_check`` forward after every successful read.
        """
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)

        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        A ResponseError returned by the parser is raised. Socket timeouts
        become TimeoutError and other OSErrors become ConnectionError. On any
        exception the connection is closed unless ``disconnect_on_error`` is
        false.
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol.

        Small chunks are merged into buffers. A buffer is flushed once it
        exceeds ``_buffer_cutoff``. Chunks larger than the cutoff, and every
        memoryview, are emitted as standalone items.
        """
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def get_protocol(self) -> Union[int, str]:
        """Return the RESP protocol version chosen at construction time."""
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Reply to the HELLO command, or None if no HELLO was sent."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` to be used by the next ``re_auth`` call."""
        self._re_auth_token = token

    def re_auth(self):
        """Send AUTH with the stored token (if any), read the reply, then clear it.

        AUTH receives ``token.try_get("oid")`` as the username and
        ``token.get_value()`` as the password.
        """
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution of ``self.host``/``self.port``
        (defaulting to "localhost"/6379 when those attributes are missing).

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            if self._sock is not None:
                peer_addr = self._sock.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return addr_info[0][4][0]
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state of the connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """Return the first element of the socket's peer address, or None if unconnected."""
        if not self._sock:
            return None
        return self._sock.getpeername()[0]

    def mark_for_reconnect(self):
        """Flag the connection for reconnection; ``disconnect`` clears the flag."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if ``mark_for_reconnect`` was called since the last disconnect."""
        return self._should_reconnect

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a timeout to the live socket and the parser buffer.

        ``relaxed_timeout == -1`` restores ``self.socket_timeout``. Any other
        value, including None, is applied as-is. Does nothing when not
        connected.
        """
        if self._sock:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            self._sock.settimeout(timeout)
            self.update_parser_buffer_timeout(timeout)

    def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's read buffer, when it supports it.

        Only buffers exposing a ``socket_timeout`` attribute are updated.
        NOTE(review): presumably the hiredis parser's ``_buffer`` is a raw
        buffer without that attribute (or the attribute is absent entirely).
        Such parsers are skipped instead of raising AttributeError.
        """
        parser = self._parser
        buffer = getattr(parser, "_buffer", None) if parser else None
        if buffer and hasattr(buffer, "socket_timeout"):
            buffer.socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        ``tmp_relaxed_timeout == -1`` leaves both timeouts unchanged. Note that
        the default (None) is not -1, so omitting it sets ``socket_timeout``
        and ``socket_connect_timeout`` to None.
        """
        if tmp_host_address is not SENTINEL:
            self.host = tmp_host_address
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the host and/or timeouts recorded at construction time."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
1074
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        Args:
            host: Hostname or IP address of the server.
            port: TCP port; converted with ``int()``.
            socket_keepalive: Enable ``SO_KEEPALIVE`` on new sockets.
            socket_keepalive_options: Mapping of ``IPPROTO_TCP`` option ->
                value, applied only when keepalive is enabled.
            socket_type: Address-family hint passed to ``socket.getaddrinfo``
                (0 means any family).
            **kwargs: Forwarded to ``AbstractConnection.__init__``.
        """
        # Set before the base initializer runs, which may already use them.
        self.host = host
        self.port = int(port)
        self.socket_type = socket_type
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return ``(name, value)`` pairs describing this connection for repr()."""
        base = [("host", self.host), ("port", self.port), ("db", self.db)]
        if not self.client_name:
            return base
        return base + [("client_name", self.client_name)]

    def _connect(self):
        """Open a TCP socket to ``host:port``, trying each resolved address in turn.

        Mirrors ``socket.create_connection()`` (IPv4/IPv6 fallback) but sets
        TCP_NODELAY, optional keepalive options and the connect timeout
        before ``connect()`` is called.

        Returns:
            The connected socket, with ``socket_timeout`` applied.

        Raises:
            OSError: the error from the last address tried, or a generic one
                when ``getaddrinfo`` yields no addresses.
        """
        last_error = None
        addresses = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )
        for family, socktype, proto, _canonname, address in addresses:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, setting in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, option, setting)

                # The connect timeout bounds only the connect() call; the
                # regular socket timeout takes over once connected.
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)
                sock.settimeout(self.socket_timeout)
                return sock
            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is None:
            raise OSError("socket.getaddrinfo returned an empty list")
        raise last_error

    def _host_error(self):
        """Return ``host:port`` for use in error messages."""
        return f"{self.host}:{self.port}"
1148
1149
class CacheProxyConnection(ConnectionInterface):
    """Connection wrapper that serves cachable read-only commands from a client-side cache.

    Commands that ``cache.is_cachable()`` accepts are answered from the cache
    when a completed entry exists.  Otherwise they go over the wrapped
    connection and the reply is stored.  Each (re)connect runs
    ``CLIENT TRACKING ON`` and installs a push handler that evicts
    invalidated keys.  Everything else is delegated to the wrapped connection.
    """

    # Placeholder value stored while the reply for an entry is still pending.
    DUMMY_CACHE_VALUE = b"foo"
    # Lowest server version accepted by connect().
    MIN_ALLOWED_VERSION = "7.4.0"
    # Server name (from the handshake metadata) accepted by connect().
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        Args:
            conn: The wrapped connection that actually talks to the server.
            cache: The cache to read from and populate.
            pool_lock: Lock held while draining pending replies of another
                connection that owns a cache entry (see send_command()).
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command whose reply read_response() must serve or
        # store; None when the current command bypasses the cache.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching.

        Raises:
            ConnectionError: if the handshake metadata lacks the server name
                or version, or the server is not ``DEFAULT_SERVER_NAME`` at
                ``MIN_ALLOWED_VERSION`` or later.
        """
        self._conn.connect()

        metadata = self._conn.handshake_metadata
        # Keys may be bytes or str (see the handshake_metadata return type).
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        # Both values are needed below; a missing name must not slip through.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        # NOTE(review): relies on compare_versions() returning 1 when the first
        # version is lower than the second — confirm in redis.utils.
        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send an already-packed command; it always bypasses the cache."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, or arrange for read_response() to answer it from the cache.

        ``args[0]`` is the command name.  Cachable commands must also pass the
        Redis keys they read via the ``keys`` keyword argument.

        Raises:
            ValueError: if a cachable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Write commands and commands the cache refuses go straight out.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        cache_key = CacheKey(command=args[0], redis_keys=tuple(kwargs.get("keys")))
        self._current_command_cache_key = cache_key

        with self._cache_lock:
            # Look the entry up once: the cache is shared with other proxies
            # (each with its own lock), so it may change between lookups.
            entry = self._cache.get(cache_key)
            if entry:
                # If another connection cached this entry, drain its pending
                # replies so queued invalidations are applied before the
                # cached value is served.
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)
                return

            # Reserve the entry so other connections know a reply is pending.
            self._cache.set(
                CacheEntry(
                    cache_key=cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Only cachable read-only commands that are not yet cached get here.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply to the last command, from the cache when possible.

        A completed cache entry is returned as a deep copy.  Otherwise the
        reply is read from the wrapped connection; for cachable commands it
        is stored (non-None replies) or the pending entry is dropped (None).
        """
        with self._cache_lock:
            cache_key = self._current_command_cache_key
            if cache_key is not None:
                # Single lookup: the entry may vanish between repeated calls.
                cached = self._cache.get(cache_key)
                if (
                    cached is not None
                    and cached.status != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(cached.cache_value)
                    self._current_command_cache_key = None
                    return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            cache_key = self._current_command_cache_key
            # Prevent not-allowed command from caching.
            if cache_key is None:
                return response

            if response is None:
                # Never cache None replies.
                self._cache.delete_by_cache_keys([cache_key])
            else:
                cache_entry = self._cache.get(cache_key)
                # Cache only entries that are still present, i.e. not
                # invalidated by another connection in the meantime.
                if cache_entry is not None:
                    cache_entry.status = CacheEntryStatus.VALID
                    cache_entry.cache_value = response
                    self._cache.set(cache_entry)

            # Clear the key on every path so a later read cannot be answered
            # from a stale key.
            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def _connect(self):
        """Delegate to the wrapped connection and return its socket."""
        return self._conn._connect()

    def _host_error(self):
        """Return the wrapped connection's host description for error messages."""
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Turn on CLIENT TRACKING for ``conn`` and route its invalidation pushes here."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Read all pending push messages so their invalidations are applied."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys; a ``None`` key list flushes the whole cache."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

    def get_protocol(self):
        return self._conn.get_protocol()

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()
1362
1363
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                           or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.

        The plain socket is closed if wrapping fails with an OSError or
        RedisError.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if pure OCSP validation is requested without the
                cryptography package, or both OCSP modes are requested.
            ConnectionError: if pure OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                self._validate_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            # wrap_socket() detaches the plain socket, so the sock.close() in
            # _connect() cannot release the descriptor; close the TLS socket.
            sslsock.close()
            raise
        return sslsock

    def _validate_ocsp_staple(self):
        """Verify the server's stapled OCSP response over a second, short-lived connection.

        Errors raised by the OpenSSL handshake or the staple verifier
        propagate to the caller; the helper socket is always closed.
        """
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; close it whether or not the handshake succeeds
        raw_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, raw_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            raw_sock.close()
1546
1547
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        Args:
            path: Filesystem path of the server's Unix socket.
            socket_timeout: Timeout applied to the socket once connected.
            **kwargs: Forwarded to ``AbstractConnection.__init__``.
        """
        super().__init__(**kwargs)
        # Assigned after the base initializer, so these values win over
        # anything it may have set.
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return ``(name, value)`` pairs describing this connection for repr()."""
        base = [("path", self.path), ("db", self.db)]
        if not self.client_name:
            return base
        return base + [("client_name", self.client_name)]

    def _connect(self):
        """Open a Unix domain socket to ``self.path``.

        ``socket_connect_timeout`` bounds the connect; ``socket_timeout`` is
        applied once connected.  On failure the socket is shut down and
        closed (preventing ResourceWarnings) before the OSError is re-raised.
        """
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        else:
            uds.settimeout(self.socket_timeout)
            return uds

    def _host_error(self):
        """Return the socket path for use in error messages."""
        return self.path
1581
1582
# Upper-cased spellings that to_bool() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Interpret a URL query-string value as a boolean.

    Returns ``None`` for ``None`` or ``""``, ``False`` for a string whose
    upper-cased form is in ``FALSE_STRINGS`` (e.g. "no", "false", "0"), and
    ``bool(value)`` for anything else.
    """
    if value is None or value == "":
        return None
    is_false_word = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_word else bool(value)
1592
1593
def parse_ssl_verify_flags(value):
    """Parse ``ssl.VerifyFlags`` member names from a URL query value.

    The value is a comma-separated list of flag names, optionally wrapped in
    square brackets, e.g. ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``.

    Returns:
        A list of ``VerifyFlags`` members in input order.  Blank entries
        (``"[]"`` or a trailing comma) are ignored.

    Raises:
        ValueError: if a name is not a member of ``VerifyFlags``.  This
            applies to every name when Python has no SSL support.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # VerifyFlags is None without SSL support; then no name is valid.
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        # Only real enum members are accepted; hasattr() would also let
        # through unrelated attributes such as "__doc__".
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
1606
1607
# Converters applied by parse_url() to query-string values, keyed by argument
# name.  Names not listed here are passed through as (unquoted) strings.  A
# converter raising TypeError/ValueError makes parse_url() raise ValueError.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    # NOTE(review): list() on a query string yields a list of single
    # characters, not a list of exception classes — confirm intended use.
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}
1622
1623
def parse_url(url):
    """Translate a redis://, rediss:// or unix:// URL into connection kwargs.

    Query-string values are unquoted and converted via
    ``URL_QUERY_ARGUMENT_PARSERS`` (unknown names stay strings).  Username and
    password come from the URL's userinfo.  For unix:// the path becomes
    ``path``; for redis:// and rediss:// the host and port are used, and the
    path becomes ``db`` unless the query string already set one.  The unix://
    and rediss:// schemes also set ``connection_class``.

    Raises:
        ValueError: for an unsupported scheme or a query value that fails
            conversion.
    """
    supported = ("redis://", "rediss://", "unix://")
    if not url.startswith(supported):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not convert:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    for credential in ("username", "password"):
        secret = getattr(parsed, credential)
        if secret:
            kwargs[credential] = unquote(secret)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # Remaining schemes: redis:// and rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A path such as "/2" selects the db unless the query string set one.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs
1679
1680
# Bound type variable so that ConnectionPool.from_url() called on a subclass is
# typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
1682
1683
1684class ConnectionPool:
1685 """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
1694
1695 Any additional keyword arguments are passed to the constructor of
1696 ``connection_class``.
1697 """
1698
1699 @classmethod
1700 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
1701 """
1702 Return a connection pool configured from the given URL.
1703
1704 For example::
1705
1706 redis://[[username]:[password]]@localhost:6379/0
1707 rediss://[[username]:[password]]@localhost:6379/0
1708 unix://[username@]/path/to/socket.sock?db=0[&password=password]
1709
1710 Three URL schemes are supported:
1711
1712 - `redis://` creates a TCP socket connection. See more at:
1713 <https://www.iana.org/assignments/uri-schemes/prov/redis>
1714 - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
1715 <https://www.iana.org/assignments/uri-schemes/prov/rediss>
1716 - ``unix://``: creates a Unix Domain Socket connection.
1717
1718 The username, password, hostname, path and all querystring values
1719 are passed through urllib.parse.unquote in order to replace any
1720 percent-encoded values with their corresponding characters.
1721
1722 There are several ways to specify a database number. The first value
1723 found will be used:
1724
1725 1. A ``db`` querystring option, e.g. redis://localhost?db=0
1726 2. If using the redis:// or rediss:// schemes, the path argument
1727 of the url, e.g. redis://localhost/0
1728 3. A ``db`` keyword argument to this function.
1729
1730 If none of these options are specified, the default db=0 is used.
1731
1732 All querystring options are cast to their appropriate Python types.
1733 Boolean arguments can be specified with string values "True"/"False"
1734 or "Yes"/"No". Values that cannot be properly cast cause a
1735 ``ValueError`` to be raised. Once parsed, the querystring arguments
1736 and keyword arguments are passed to the ``ConnectionPool``'s
1737 class initializer. In the case of conflicting arguments, querystring
1738 arguments always win.
1739 """
1740 url_options = parse_url(url)
1741
1742 if "connection_class" in kwargs:
1743 url_options["connection_class"] = kwargs["connection_class"]
1744
1745 kwargs.update(url_options)
1746 return cls(**kwargs)
1747
    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        """Configure the pool; connections are created later by make_connection().

        Args:
            connection_class: Class instantiated for each new connection.
            max_connections: Upper bound on connections created by the pool.
                ``None`` or ``0`` is replaced with ``2**31``.
            cache_factory: Builds the client-side cache when caching is
                requested without an explicit ``cache`` object.
            **connection_kwargs: Passed to ``connection_class``.  ``cache``
                and ``cache_config`` are consumed here and removed.

        Raises:
            ValueError: if ``max_connections`` is negative or not an int, or
                ``cache`` does not implement CacheInterface.
            RedisError: if client caching or maintenance notifications are
                requested without ``protocol`` 3.
        """
        # ``or`` maps both None and 0 to the effectively unbounded default.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        # Same dict object as ``connection_kwargs``: the pops and updates below
        # are visible through self.connection_kwargs too.
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self.connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        # The cache lives on the pool; these keys are not forwarded to
        # connection_class.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        if self.connection_kwargs.get(
            "maint_notifications_pool_handler"
        ) or self.connection_kwargs.get("maint_notifications_config"):
            if self.connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError(
                    "Push handlers on connection are only supported with RESP version 3"
                )
            # An explicit config wins; otherwise use the handler's config.
            config = self.connection_kwargs.get("maint_notifications_config", None) or (
                self.connection_kwargs.get("maint_notifications_pool_handler").config
                if self.connection_kwargs.get("maint_notifications_pool_handler")
                else None
            )

            if config and config.enabled:
                # Record orig_host_address / orig_socket_*timeout kwargs.
                self._update_connection_kwargs_for_maint_notifications()

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()
1820
1821 def __repr__(self) -> str:
1822 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
1823 return (
1824 f"<{self.__class__.__module__}.{self.__class__.__name__}"
1825 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
1826 f"({conn_kwargs})>)>"
1827 )
1828
1829 def get_protocol(self):
1830 """
1831 Returns:
1832 The RESP protocol version, or ``None`` if the protocol is not specified,
1833 in which case the server default will be used.
1834 """
1835 return self.connection_kwargs.get("protocol", None)
1836
1837 def maint_notifications_pool_handler_enabled(self):
1838 """
1839 Returns:
1840 True if the maintenance notifications pool handler is enabled, False otherwise.
1841 """
1842 maint_notifications_config = self.connection_kwargs.get(
1843 "maint_notifications_config", None
1844 )
1845
1846 return maint_notifications_config and maint_notifications_config.enabled
1847
1848 def set_maint_notifications_pool_handler(
1849 self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
1850 ):
1851 self.connection_kwargs.update(
1852 {
1853 "maint_notifications_pool_handler": maint_notifications_pool_handler,
1854 "maint_notifications_config": maint_notifications_pool_handler.config,
1855 }
1856 )
1857 self._update_connection_kwargs_for_maint_notifications()
1858
1859 self._update_maint_notifications_configs_for_connections(
1860 maint_notifications_pool_handler
1861 )
1862
1863 def _update_maint_notifications_configs_for_connections(
1864 self, maint_notifications_pool_handler
1865 ):
1866 """Update the maintenance notifications config for all connections in the pool."""
1867 with self._lock:
1868 for conn in self._available_connections:
1869 conn.set_maint_notifications_pool_handler(
1870 maint_notifications_pool_handler
1871 )
1872 conn.maint_notifications_config = (
1873 maint_notifications_pool_handler.config
1874 )
1875 for conn in self._in_use_connections:
1876 conn.set_maint_notifications_pool_handler(
1877 maint_notifications_pool_handler
1878 )
1879 conn.maint_notifications_config = (
1880 maint_notifications_pool_handler.config
1881 )
1882
1883 def _update_connection_kwargs_for_maint_notifications(self):
1884 """Store original connection parameters for maintenance notifications."""
1885 if self.connection_kwargs.get("orig_host_address", None) is None:
1886 # If orig_host_address is None it means we haven't
1887 # configured the original values yet
1888 self.connection_kwargs.update(
1889 {
1890 "orig_host_address": self.connection_kwargs.get("host"),
1891 "orig_socket_timeout": self.connection_kwargs.get(
1892 "socket_timeout", None
1893 ),
1894 "orig_socket_connect_timeout": self.connection_kwargs.get(
1895 "socket_connect_timeout", None
1896 ),
1897 }
1898 )
1899
    def reset(self) -> None:
        """Drop all connection bookkeeping and record the current process id.

        Existing connection objects are not closed here; the pool only
        rebinds its counters and containers to fresh empty ones.
        """
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()
1915
    def _checkpid(self) -> None:
        """Call reset() if the current pid differs from the one reset() recorded.

        Raises:
            ChildDeadlockedError: if ``_fork_lock`` cannot be acquired within
                5 seconds.
        """
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()
1962
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """Get a connection from the pool, connected and ready for a command.

        An idle connection is reused when available; otherwise a new one is
        created with make_connection() (which may raise MaxConnectionsError).
        ``command_name``, ``keys`` and ``options`` are deprecated and are not
        used by this body.

        Raises:
            ConnectionError: if a reconnected connection still has data to read.
            Any error from connecting; the connection is released back to the
            pool before the error propagates.
        """

        self._checkpid()
        with self._lock:
            # Take the last idle connection in the list, or create a new one.
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # The unread-data check is skipped with a client cache or
            # maintenance notifications, presumably because those connections
            # may legitimately hold unread push messages.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_pool_handler_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection
2004
2005 def get_encoder(self) -> Encoder:
2006 "Return an encoder based on encoding settings"
2007 kwargs = self.connection_kwargs
2008 return Encoder(
2009 encoding=kwargs.get("encoding", "utf-8"),
2010 encoding_errors=kwargs.get("encoding_errors", "strict"),
2011 decode_responses=kwargs.get("decode_responses", False),
2012 )
2013
2014 def make_connection(self) -> "ConnectionInterface":
2015 "Create a new connection"
2016 if self._created_connections >= self.max_connections:
2017 raise MaxConnectionsError("Too many connections")
2018 self._created_connections += 1
2019
2020 kwargs = dict(self.connection_kwargs)
2021
2022 if self.cache is not None:
2023 return CacheProxyConnection(
2024 self.connection_class(**kwargs), self.cache, self._lock
2025 )
2026 return self.connection_class(**kwargs)
2027
2028 def release(self, connection: "Connection") -> None:
2029 "Releases the connection back to the pool"
2030 self._checkpid()
2031 with self._lock:
2032 try:
2033 self._in_use_connections.remove(connection)
2034 except KeyError:
2035 # Gracefully fail when a connection is returned to this pool
2036 # that the pool doesn't actually own
2037 return
2038
2039 if self.owns_connection(connection):
2040 if connection.should_reconnect():
2041 connection.disconnect()
2042 self._available_connections.append(connection)
2043 self._event_dispatcher.dispatch(
2044 AfterConnectionReleasedEvent(connection)
2045 )
2046 else:
2047 # Pool doesn't own this connection, do not add it back
2048 # to the pool.
2049 # The created connections count should not be changed,
2050 # because the connection was not created by the pool.
2051 connection.disconnect()
2052 return
2053
2054 def owns_connection(self, connection: "Connection") -> int:
2055 return connection.pid == self.pid
2056
2057 def disconnect(self, inuse_connections: bool = True) -> None:
2058 """
2059 Disconnects connections in the pool
2060
2061 If ``inuse_connections`` is True, disconnect connections that are
2062 current in use, potentially by other threads. Otherwise only disconnect
2063 connections that are idle in the pool.
2064 """
2065 self._checkpid()
2066 with self._lock:
2067 if inuse_connections:
2068 connections = chain(
2069 self._available_connections, self._in_use_connections
2070 )
2071 else:
2072 connections = self._available_connections
2073
2074 for connection in connections:
2075 connection.disconnect()
2076
2077 def close(self) -> None:
2078 """Close the pool, disconnecting all connections"""
2079 self.disconnect()
2080
2081 def set_retry(self, retry: Retry) -> None:
2082 self.connection_kwargs.update({"retry": retry})
2083 for conn in self._available_connections:
2084 conn.retry = retry
2085 for conn in self._in_use_connections:
2086 conn.retry = retry
2087
2088 def re_auth_callback(self, token: TokenInterface):
2089 with self._lock:
2090 for conn in self._available_connections:
2091 conn.retry.call_with_retry(
2092 lambda: conn.send_command(
2093 "AUTH", token.try_get("oid"), token.get_value()
2094 ),
2095 lambda error: self._mock(error),
2096 )
2097 conn.retry.call_with_retry(
2098 lambda: conn.read_response(), lambda error: self._mock(error)
2099 )
2100 for conn in self._in_use_connections:
2101 conn.set_re_auth_token(token)
2102
2103 def _should_update_connection(
2104 self,
2105 conn: "Connection",
2106 matching_pattern: Literal[
2107 "connected_address", "configured_address", "notification_hash"
2108 ] = "connected_address",
2109 matching_address: Optional[str] = None,
2110 matching_notification_hash: Optional[int] = None,
2111 ) -> bool:
2112 """
2113 Check if the connection should be updated based on the matching criteria.
2114 """
2115 if matching_pattern == "connected_address":
2116 if matching_address and conn.getpeername() != matching_address:
2117 return False
2118 elif matching_pattern == "configured_address":
2119 if matching_address and conn.host != matching_address:
2120 return False
2121 elif matching_pattern == "notification_hash":
2122 if (
2123 matching_notification_hash
2124 and conn.maintenance_notification_hash != matching_notification_hash
2125 ):
2126 return False
2127 return True
2128
2129 def update_connection_settings(
2130 self,
2131 conn: "Connection",
2132 state: Optional["MaintenanceState"] = None,
2133 maintenance_notification_hash: Optional[int] = None,
2134 host_address: Optional[str] = None,
2135 relaxed_timeout: Optional[float] = None,
2136 update_notification_hash: bool = False,
2137 reset_host_address: bool = False,
2138 reset_relaxed_timeout: bool = False,
2139 ):
2140 """
2141 Update the settings for a single connection.
2142 """
2143 if state:
2144 conn.maintenance_state = state
2145
2146 if update_notification_hash:
2147 # update the notification hash only if requested
2148 conn.maintenance_notification_hash = maintenance_notification_hash
2149
2150 if host_address is not None:
2151 conn.set_tmp_settings(tmp_host_address=host_address)
2152
2153 if relaxed_timeout is not None:
2154 conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
2155
2156 if reset_relaxed_timeout or reset_host_address:
2157 conn.reset_tmp_settings(
2158 reset_host_address=reset_host_address,
2159 reset_relaxed_timeout=reset_relaxed_timeout,
2160 )
2161
2162 conn.update_current_socket_timeout(relaxed_timeout)
2163
2164 def update_connections_settings(
2165 self,
2166 state: Optional["MaintenanceState"] = None,
2167 maintenance_notification_hash: Optional[int] = None,
2168 host_address: Optional[str] = None,
2169 relaxed_timeout: Optional[float] = None,
2170 matching_address: Optional[str] = None,
2171 matching_notification_hash: Optional[int] = None,
2172 matching_pattern: Literal[
2173 "connected_address", "configured_address", "notification_hash"
2174 ] = "connected_address",
2175 update_notification_hash: bool = False,
2176 reset_host_address: bool = False,
2177 reset_relaxed_timeout: bool = False,
2178 include_free_connections: bool = True,
2179 ):
2180 """
2181 Update the settings for all matching connections in the pool.
2182
2183 This method does not create new connections.
2184 This method does not affect the connection kwargs.
2185
2186 :param state: The maintenance state to set for the connection.
2187 :param maintenance_notification_hash: The hash of the maintenance notification
2188 to set for the connection.
2189 :param host_address: The host address to set for the connection.
2190 :param relaxed_timeout: The relaxed timeout to set for the connection.
2191 :param matching_address: The address to match for the connection.
2192 :param matching_notification_hash: The notification hash to match for the connection.
2193 :param matching_pattern: The pattern to match for the connection.
2194 :param update_notification_hash: Whether to update the notification hash for the connection.
2195 :param reset_host_address: Whether to reset the host address to the original address.
2196 :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2197 :param include_free_connections: Whether to include free/available connections.
2198 """
2199 with self._lock:
2200 for conn in self._in_use_connections:
2201 if self._should_update_connection(
2202 conn,
2203 matching_pattern,
2204 matching_address,
2205 matching_notification_hash,
2206 ):
2207 self.update_connection_settings(
2208 conn,
2209 state=state,
2210 maintenance_notification_hash=maintenance_notification_hash,
2211 host_address=host_address,
2212 relaxed_timeout=relaxed_timeout,
2213 update_notification_hash=update_notification_hash,
2214 reset_host_address=reset_host_address,
2215 reset_relaxed_timeout=reset_relaxed_timeout,
2216 )
2217
2218 if include_free_connections:
2219 for conn in self._available_connections:
2220 if self._should_update_connection(
2221 conn,
2222 matching_pattern,
2223 matching_address,
2224 matching_notification_hash,
2225 ):
2226 self.update_connection_settings(
2227 conn,
2228 state=state,
2229 maintenance_notification_hash=maintenance_notification_hash,
2230 host_address=host_address,
2231 relaxed_timeout=relaxed_timeout,
2232 update_notification_hash=update_notification_hash,
2233 reset_host_address=reset_host_address,
2234 reset_relaxed_timeout=reset_relaxed_timeout,
2235 )
2236
2237 def update_connection_kwargs(
2238 self,
2239 **kwargs,
2240 ):
2241 """
2242 Update the connection kwargs for all future connections.
2243
2244 This method updates the connection kwargs for all future connections created by the pool.
2245 Existing connections are not affected.
2246 """
2247 self.connection_kwargs.update(kwargs)
2248
2249 def update_active_connections_for_reconnect(
2250 self,
2251 moving_address_src: Optional[str] = None,
2252 ):
2253 """
2254 Mark all active connections for reconnect.
2255 This is used when a cluster node is migrated to a different address.
2256
2257 :param moving_address_src: The address of the node that is being moved.
2258 """
2259 with self._lock:
2260 for conn in self._in_use_connections:
2261 if self._should_update_connection(
2262 conn, "connected_address", moving_address_src
2263 ):
2264 conn.mark_for_reconnect()
2265
2266 def disconnect_free_connections(
2267 self,
2268 moving_address_src: Optional[str] = None,
2269 ):
2270 """
2271 Disconnect all free/available connections.
2272 This is used when a cluster node is migrated to a different address.
2273
2274 :param moving_address_src: The address of the node that is being moved.
2275 """
2276 with self._lock:
2277 for conn in self._available_connections:
2278 if self._should_update_connection(
2279 conn, "connected_address", moving_address_src
2280 ):
2281 conn.disconnect()
2282
2283 async def _mock(self, error: RedisError):
2284 """
2285 Dummy functions, needs to be passed as error callback to retry object.
2286 :param error:
2287 :return:
2288 """
2289 pass
2290
2291
2292class BlockingConnectionPool(ConnectionPool):
2293 """
2294 Thread-safe blocking connection pool::
2295
2296 >>> from redis.client import Redis
2297 >>> client = Redis(connection_pool=BlockingConnectionPool())
2298
2299 It performs the same function as the default
2300 :py:class:`~redis.ConnectionPool` implementation, in that,
2301 it maintains a pool of reusable connections that can be shared by
2302 multiple redis clients (safely across threads if required).
2303
2304 The difference is that, in the event that a client tries to get a
    connection from the pool when all of its connections are in use, rather than
2306 raising a :py:class:`~redis.ConnectionError` (as the default
2307 :py:class:`~redis.ConnectionPool` implementation does), it
2308 makes the client wait ("blocks") for a specified number of seconds until
2309 a connection becomes available.
2310
2311 Use ``max_connections`` to increase / decrease the pool size::
2312
2313 >>> pool = BlockingConnectionPool(max_connections=10)
2314
2315 Use ``timeout`` to tell it either how many seconds to wait for a connection
2316 to become available, or to block forever:
2317
2318 >>> # Block forever.
2319 >>> pool = BlockingConnectionPool(timeout=None)
2320
2321 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
2322 >>> # not available.
2323 >>> pool = BlockingConnectionPool(timeout=5)
2324 """
2325
2326 def __init__(
2327 self,
2328 max_connections=50,
2329 timeout=20,
2330 connection_class=Connection,
2331 queue_class=LifoQueue,
2332 **connection_kwargs,
2333 ):
2334 self.queue_class = queue_class
2335 self.timeout = timeout
2336 self._in_maintenance = False
2337 self._locked = False
2338 super().__init__(
2339 connection_class=connection_class,
2340 max_connections=max_connections,
2341 **connection_kwargs,
2342 )
2343
2344 def reset(self):
2345 # Create and fill up a thread safe queue with ``None`` values.
2346 try:
2347 if self._in_maintenance:
2348 self._lock.acquire()
2349 self._locked = True
2350 self.pool = self.queue_class(self.max_connections)
2351 while True:
2352 try:
2353 self.pool.put_nowait(None)
2354 except Full:
2355 break
2356
2357 # Keep a list of actual connection instances so that we can
2358 # disconnect them later.
2359 self._connections = []
2360 finally:
2361 if self._locked:
2362 try:
2363 self._lock.release()
2364 except Exception:
2365 pass
2366 self._locked = False
2367
2368 # this must be the last operation in this method. while reset() is
2369 # called when holding _fork_lock, other threads in this process
2370 # can call _checkpid() which compares self.pid and os.getpid() without
2371 # holding any lock (for performance reasons). keeping this assignment
2372 # as the last operation ensures that those other threads will also
2373 # notice a pid difference and block waiting for the first thread to
2374 # release _fork_lock. when each of these threads eventually acquire
2375 # _fork_lock, they will notice that another thread already called
2376 # reset() and they will immediately release _fork_lock and continue on.
2377 self.pid = os.getpid()
2378
2379 def make_connection(self):
2380 "Make a fresh connection."
2381 try:
2382 if self._in_maintenance:
2383 self._lock.acquire()
2384 self._locked = True
2385
2386 if self.cache is not None:
2387 connection = CacheProxyConnection(
2388 self.connection_class(**self.connection_kwargs),
2389 self.cache,
2390 self._lock,
2391 )
2392 else:
2393 connection = self.connection_class(**self.connection_kwargs)
2394 self._connections.append(connection)
2395 return connection
2396 finally:
2397 if self._locked:
2398 try:
2399 self._lock.release()
2400 except Exception:
2401 pass
2402 self._locked = False
2403
2404 @deprecated_args(
2405 args_to_warn=["*"],
2406 reason="Use get_connection() without args instead",
2407 version="5.3.0",
2408 )
2409 def get_connection(self, command_name=None, *keys, **options):
2410 """
2411 Get a connection, blocking for ``self.timeout`` until a connection
2412 is available from the pool.
2413
2414 If the connection returned is ``None`` then creates a new connection.
2415 Because we use a last-in first-out queue, the existing connections
2416 (having been returned to the pool after the initial ``None`` values
2417 were added) will be returned before ``None`` values. This means we only
2418 create new connections when we need to, i.e.: the actual number of
2419 connections will only increase in response to demand.
2420 """
2421 # Make sure we haven't changed process.
2422 self._checkpid()
2423
2424 # Try and get a connection from the pool. If one isn't available within
2425 # self.timeout then raise a ``ConnectionError``.
2426 connection = None
2427 try:
2428 if self._in_maintenance:
2429 self._lock.acquire()
2430 self._locked = True
2431 try:
2432 connection = self.pool.get(block=True, timeout=self.timeout)
2433 except Empty:
2434 # Note that this is not caught by the redis client and will be
2435 # raised unless handled by application code. If you want never to
2436 raise ConnectionError("No connection available.")
2437
2438 # If the ``connection`` is actually ``None`` then that's a cue to make
2439 # a new connection to add to the pool.
2440 if connection is None:
2441 connection = self.make_connection()
2442 finally:
2443 if self._locked:
2444 try:
2445 self._lock.release()
2446 except Exception:
2447 pass
2448 self._locked = False
2449
2450 try:
2451 # ensure this connection is connected to Redis
2452 connection.connect()
2453 # connections that the pool provides should be ready to send
2454 # a command. if not, the connection was either returned to the
2455 # pool before all data has been read or the socket has been
2456 # closed. either way, reconnect and verify everything is good.
2457 try:
2458 if connection.can_read():
2459 raise ConnectionError("Connection has data")
2460 except (ConnectionError, TimeoutError, OSError):
2461 connection.disconnect()
2462 connection.connect()
2463 if connection.can_read():
2464 raise ConnectionError("Connection not ready")
2465 except BaseException:
2466 # release the connection back to the pool so that we don't leak it
2467 self.release(connection)
2468 raise
2469
2470 return connection
2471
2472 def release(self, connection):
2473 "Releases the connection back to the pool."
2474 # Make sure we haven't changed process.
2475 self._checkpid()
2476
2477 try:
2478 if self._in_maintenance:
2479 self._lock.acquire()
2480 self._locked = True
2481 if not self.owns_connection(connection):
2482 # pool doesn't own this connection. do not add it back
2483 # to the pool. instead add a None value which is a placeholder
2484 # that will cause the pool to recreate the connection if
2485 # its needed.
2486 connection.disconnect()
2487 self.pool.put_nowait(None)
2488 return
2489 if connection.should_reconnect():
2490 connection.disconnect()
2491 # Put the connection back into the pool.
2492 try:
2493 self.pool.put_nowait(connection)
2494 except Full:
2495 # perhaps the pool has been reset() after a fork? regardless,
2496 # we don't want this connection
2497 pass
2498 finally:
2499 if self._locked:
2500 try:
2501 self._lock.release()
2502 except Exception:
2503 pass
2504 self._locked = False
2505
2506 def disconnect(self):
2507 "Disconnects all connections in the pool."
2508 self._checkpid()
2509 try:
2510 if self._in_maintenance:
2511 self._lock.acquire()
2512 self._locked = True
2513 for connection in self._connections:
2514 connection.disconnect()
2515 finally:
2516 if self._locked:
2517 try:
2518 self._lock.release()
2519 except Exception:
2520 pass
2521 self._locked = False
2522
2523 def update_connections_settings(
2524 self,
2525 state: Optional["MaintenanceState"] = None,
2526 maintenance_notification_hash: Optional[int] = None,
2527 relaxed_timeout: Optional[float] = None,
2528 host_address: Optional[str] = None,
2529 matching_address: Optional[str] = None,
2530 matching_notification_hash: Optional[int] = None,
2531 matching_pattern: Literal[
2532 "connected_address", "configured_address", "notification_hash"
2533 ] = "connected_address",
2534 update_notification_hash: bool = False,
2535 reset_host_address: bool = False,
2536 reset_relaxed_timeout: bool = False,
2537 include_free_connections: bool = True,
2538 ):
2539 """
2540 Override base class method to work with BlockingConnectionPool's structure.
2541 """
2542 with self._lock:
2543 if include_free_connections:
2544 for conn in tuple(self._connections):
2545 if self._should_update_connection(
2546 conn,
2547 matching_pattern,
2548 matching_address,
2549 matching_notification_hash,
2550 ):
2551 self.update_connection_settings(
2552 conn,
2553 state=state,
2554 maintenance_notification_hash=maintenance_notification_hash,
2555 host_address=host_address,
2556 relaxed_timeout=relaxed_timeout,
2557 update_notification_hash=update_notification_hash,
2558 reset_host_address=reset_host_address,
2559 reset_relaxed_timeout=reset_relaxed_timeout,
2560 )
2561 else:
2562 connections_in_queue = {conn for conn in self.pool.queue if conn}
2563 for conn in self._connections:
2564 if conn not in connections_in_queue:
2565 if self._should_update_connection(
2566 conn,
2567 matching_pattern,
2568 matching_address,
2569 matching_notification_hash,
2570 ):
2571 self.update_connection_settings(
2572 conn,
2573 state=state,
2574 maintenance_notification_hash=maintenance_notification_hash,
2575 host_address=host_address,
2576 relaxed_timeout=relaxed_timeout,
2577 update_notification_hash=update_notification_hash,
2578 reset_host_address=reset_host_address,
2579 reset_relaxed_timeout=reset_relaxed_timeout,
2580 )
2581
2582 def update_active_connections_for_reconnect(
2583 self,
2584 moving_address_src: Optional[str] = None,
2585 ):
2586 """
2587 Mark all active connections for reconnect.
2588 This is used when a cluster node is migrated to a different address.
2589
2590 :param moving_address_src: The address of the node that is being moved.
2591 """
2592 with self._lock:
2593 connections_in_queue = {conn for conn in self.pool.queue if conn}
2594 for conn in self._connections:
2595 if conn not in connections_in_queue:
2596 if self._should_update_connection(
2597 conn,
2598 matching_pattern="connected_address",
2599 matching_address=moving_address_src,
2600 ):
2601 conn.mark_for_reconnect()
2602
2603 def disconnect_free_connections(
2604 self,
2605 moving_address_src: Optional[str] = None,
2606 ):
2607 """
2608 Disconnect all free/available connections.
2609 This is used when a cluster node is migrated to a different address.
2610
2611 :param moving_address_src: The address of the node that is being moved.
2612 """
2613 with self._lock:
2614 existing_connections = self.pool.queue
2615
2616 for conn in existing_connections:
2617 if conn:
2618 if self._should_update_connection(
2619 conn, "connected_address", moving_address_src
2620 ):
2621 conn.disconnect()
2622
2623 def _update_maint_notifications_config_for_connections(
2624 self, maint_notifications_config
2625 ):
2626 for conn in tuple(self._connections):
2627 conn.maint_notifications_config = maint_notifications_config
2628
2629 def _update_maint_notifications_configs_for_connections(
2630 self, maint_notifications_pool_handler
2631 ):
2632 """Update the maintenance notifications config for all connections in the pool."""
2633 with self._lock:
2634 for conn in tuple(self._connections):
2635 conn.set_maint_notifications_pool_handler(
2636 maint_notifications_pool_handler
2637 )
2638 conn.maint_notifications_config = (
2639 maint_notifications_pool_handler.config
2640 )
2641
2642 def set_in_maintenance(self, in_maintenance: bool):
2643 """
2644 Sets a flag that this Blocking ConnectionPool is in maintenance mode.
2645
2646 This is used to prevent new connections from being created while we are in maintenance mode.
2647 The pool will be in maintenance mode only when we are processing a MOVING notification.
2648 """
2649 self._in_maintenance = in_maintenance