1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
12from urllib.parse import parse_qs, unquote, urlparse
13
14from redis.cache import (
15 CacheEntry,
16 CacheEntryStatus,
17 CacheFactory,
18 CacheFactoryInterface,
19 CacheInterface,
20 CacheKey,
21)
22
23from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
24from .auth.token import TokenInterface
25from .backoff import NoBackoff
26from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
27from .event import AfterConnectionReleasedEvent, EventDispatcher
28from .exceptions import (
29 AuthenticationError,
30 AuthenticationWrongNumberOfArgsError,
31 ChildDeadlockedError,
32 ConnectionError,
33 DataError,
34 MaxConnectionsError,
35 RedisError,
36 ResponseError,
37 TimeoutError,
38)
39from .retry import Retry
40from .utils import (
41 CRYPTOGRAPHY_AVAILABLE,
42 HIREDIS_AVAILABLE,
43 SSL_AVAILABLE,
44 compare_versions,
45 deprecated_args,
46 ensure_string,
47 format_error_message,
48 get_lib_version,
49 str_if_bytes,
50)
51
52if SSL_AVAILABLE:
53 import ssl
54else:
55 ssl = None
56
57if HIREDIS_AVAILABLE:
58 import hiredis
59
# RESP framing tokens used when serializing commands by hand.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when ``protocol`` cannot be interpreted as an integer.
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied" (used as a default value so
# that an explicitly passed ``None`` can be told apart from an omitted one).
SENTINEL = object()

# Prefer the hiredis-backed parser whenever the C extension is importable.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
DefaultParser = _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
74
75
class HiredisRespSerializer:
    """Command serializer that delegates RESP encoding to ``hiredis``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A command name containing spaces (e.g. ``"CONFIG GET"``) is split into
        separate protocol arguments first; a ``str`` name is encoded to bytes
        before being split.

        Returns:
            A one-element list holding the output of ``hiredis.pack_command``.

        Raises:
            DataError: when hiredis rejects an argument with ``TypeError``.
        """
        name, rest = args[0], args[1:]
        if isinstance(name, str):
            args = (*name.encode().split(), *rest)
        elif b" " in name:
            args = (*name.split(), *rest)

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # re-raise as the library's DataError, keeping the original traceback
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
92
93
class PythonRespSerializer:
    """Pure-Python RESP command serializer.

    Small arguments are coalesced into a single bytes buffer; any argument
    longer than ``buffer_cutoff`` (or given as a ``memoryview``), or any
    argument reached once the pending buffer already exceeds the cutoff, is
    emitted as its own output chunk instead of being copied into the buffer.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        # threshold (in bytes) above which arguments are not coalesced
        self._buffer_cutoff = buffer_cutoff
        # callable turning a single argument into bytes / memoryview
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns:
            A list of bytes-like chunks whose concatenation is the RESP
            array encoding of the command.
        """
        name = args[0]
        # The command name may carry literal sub-arguments ("CONFIG GET");
        # the server expects those as separate arguments, so split them off.
        # They are kept as bytes so the encoder leaves them untouched.
        if isinstance(name, str):
            args = (*name.encode().split(), *args[1:])
        elif b" " in name:
            args = (*name.split(), *args[1:])

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for value in map(self.encode, args):
            size = len(value)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            standalone = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(value, memoryview)
            )
            if standalone:
                # flush the buffer plus this argument's header, then hand the
                # argument over as a separate chunk to avoid a large copy
                chunks.append(pending + header)
                chunks.append(value)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, value, SYM_CRLF))
        chunks.append(pending)
        return chunks
143
144
class ConnectionInterface:
    """Interface shared by the Redis connection classes in this module.

    The methods are decorated with ``@abstractmethod``, but the class does
    not derive from ``abc.ABC``, so the markers are not enforced when the
    class is instantiated; every default body simply returns ``None``.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs used to build ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run when a connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback added with ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a fresh instance of ``parser_class`` as the reply parser."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by this connection."""

    @abstractmethod
    def connect(self):
        """Connect to the Redis server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake (auth, HELLO, SELECT, ...)."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection to the server."""

    @abstractmethod
    def check_health(self):
        """Verify that the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the server."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack a command and send it to the server."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Report whether reply data can be read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the reply to a previously sent command."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize one command into the Redis protocol."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands into the Redis protocol."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the handshake, if any."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for use by the next ``re_auth`` call."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection with the stored token."""
224
225
class AbstractConnection(ConnectionInterface):
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error=SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        A ``retry_on_error`` sequence supplied by the caller is copied before
        ``TimeoutError`` is added, so the caller's object is never mutated.

        Raises:
            DataError: if ``username``/``password`` are combined with
                ``credential_provider``.
            ConnectionError: if ``protocol`` cannot be converted to an
                integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            # the connect phase falls back to the regular socket timeout
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        elif retry_on_error is not None:
            # Work on a private copy: the append below must not mutate (and
            # keep adding duplicate TimeoutErrors to) a list the caller may
            # reuse for other connections. This also makes tuples work.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_error.append(TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            # e.g. protocol=None: fall back to the default RESP version
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # Validated here rather than in a ``finally`` clause: there, ``p``
        # would be unbound after a ValueError and an UnboundLocalError would
        # mask the ConnectionError raised above.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs shown by ``__repr__``."""
        pass

    def __del__(self):
        # best-effort cleanup; never let finalization raise
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Return ``packer`` if given, else the best available serializer."""
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Connect if needed, run the handshake, then fire connect callbacks.

        Args:
            check_health: forwarded to the default handshake
                (``on_connect_check_health``).
            retry_socket_connect: when True the socket connect is wrapped in
                ``self.retry``; otherwise it is attempted once.

        Raises:
            TimeoutError: if the socket connect times out.
            ConnectionError: for any other ``OSError`` during connect.
            RedisError: re-raised from the handshake after disconnecting.
        """
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return a connected socket."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a host description used in error messages."""
        pass

    def _error_message(self, exception):
        """Format ``exception`` together with the host description."""
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        """Run the default handshake with health checks enabled."""
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection, authenticate and select a database.

        Depending on configuration this sends HELLO (optionally with AUTH),
        AUTH, CLIENT SETNAME, CLIENT SETINFO and SELECT.

        Raises:
            AuthenticationError: if AUTH does not answer ``OK``.
            ConnectionError: on a RESP version mismatch in the HELLO reply,
                or if CLIENT SETNAME / SELECT do not answer ``OK``.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            # NOTE: unlike the plain HELLO branch below, the negotiated
            # "proto" value is not verified here.
            self.handshake_metadata = self.read_response()
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # CLIENT SETINFO errors are ignored (ResponseError is swallowed)
        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # only shut the socket down in the process that created the connection
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server.

        Connects first if needed. Any failure while writing disconnects the
        connection before the error propagates.

        Raises:
            TimeoutError: if writing to the socket times out.
            ConnectionError: for any other ``OSError`` while writing.
        """
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) >= 2:
                errno, errmsg = e.args[0], e.args[1]
            else:
                # a single message, or (rarely) no args at all
                errno, errmsg = "UNKNOWN", (e.args[0] if e.args else e)
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read.

        Connects first if there is no socket yet.

        Raises:
            ConnectionError: if polling fails with ``OSError``.
        """
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        Args:
            disable_decoding: forwarded to the parser.
            disconnect_on_error: disconnect when reading fails.
            push_request: forwarded to the parser only for RESP3.

        Raises:
            TimeoutError / ConnectionError: on socket timeout / ``OSError``.
            ResponseError: if the server replied with an error.
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol.

        Small chunks are joined into buffers of roughly ``_buffer_cutoff``
        bytes; large chunks and memoryviews are emitted on their own.
        """
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def get_protocol(self) -> Union[int, str]:
        """Return the RESP protocol version in use."""
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Reply to the HELLO handshake, or ``None`` if HELLO was not sent."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for the next ``re_auth`` call."""
        self._re_auth_token = token

    def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None
734
735
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """Configure a TCP connection.

        Args:
            host: server host name or address.
            port: server port, coerced with ``int()``.
            socket_keepalive: turn on ``SO_KEEPALIVE`` for the socket.
            socket_keepalive_options: mapping of ``IPPROTO_TCP`` option to
                value, applied only when keepalive is enabled.
            socket_type: passed as the ``family`` argument of
                ``socket.getaddrinfo``.
            **kwargs: forwarded to ``AbstractConnection.__init__``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return host, port, db and (when set) client_name for ``repr()``."""
        name_piece = [("client_name", self.client_name)] if self.client_name else []
        return [("host", self.host), ("port", self.port), ("db", self.db)] + name_piece

    def _connect(self):
        """Create a TCP socket connection.

        Mirrors ``socket.create_connection`` (trying every address returned by
        ``getaddrinfo``) but sets socket options before ``connect()``.

        Raises:
            OSError: the last connection error, or a generic one when
                ``getaddrinfo`` yields no addresses.
        """
        last_error = None
        candidates = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )
        for family, socktype, proto, _canonname, address in candidates:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # disable Nagle's algorithm
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for option, value in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, option, value)

                # the connect phase uses its own timeout ...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)
                # ... and regular I/O uses socket_timeout afterwards
                sock.settimeout(self.socket_timeout)
                return sock
            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is None:
            raise OSError("socket.getaddrinfo returned an empty list")
        raise last_error

    def _host_error(self):
        """Describe the target as ``host:port`` for error messages."""
        return f"{self.host}:{self.port}"
808
809
class CacheProxyConnection(ConnectionInterface):
    """Wraps a connection and adds client-side caching of read-only replies.

    Cacheable commands are looked up in ``cache`` first; misses are sent to
    the wrapped connection and their replies stored. Server-side invalidation
    messages (enabled with ``CLIENT TRACKING ON`` on every connect) evict or
    flush cached entries.
    """

    # placeholder value stored while a reply is still being fetched
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        Args:
            conn: the connection to delegate all I/O to.
            cache: cache storing replies keyed by ``CacheKey``.
            pool_lock: lock held while draining another connection's
                pending invalidation messages.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # key of the command whose reply is expected next (None = uncached)
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def connect(self):
        """Connect the wrapped connection and check the server supports caching.

        Raises:
            ConnectionError: if the handshake metadata lacks the server name
                or version, or the server is not Redis 7.4 or later.
        """
        self._conn.connect()

        # no HELLO reply (None) is treated like a reply without the fields
        metadata = self._conn.handshake_metadata or {}
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the whole cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command, bypassing the cache."""
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command unless its reply is already cached.

        Cacheable commands must pass their Redis keys via ``keys=``.

        Raises:
            ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        keys = kwargs.get("keys")
        if keys is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(keys)
        )

        with self._cache_lock:
            entry = self._cache.get(self._current_command_cache_key)
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if entry:
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply for the last command, from the cache if possible.

        A reply read from the server for a cacheable command is stored in the
        cache unless it is ``None`` or the placeholder entry was invalidated
        in the meantime. Cached replies are returned as deep copies.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                cached = self._cache.get(self._current_command_cache_key)
                if (
                    cached is not None
                    and cached.status != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(cached.cache_value)
                    self._current_command_cache_key = None
                    return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                # the pending command is answered; don't leave a stale key
                # that a later read could wrongly serve from the cache
                self._current_command_cache_key = None
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def _connect(self):
        return self._conn._connect()

    def _host_error(self):
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Turn on server-assisted tracking and route invalidations to us."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain any push messages already waiting on the socket."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys; ``data[1] is None`` means flush everything."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

    def get_protocol(self):
        return self._conn.get_protocol()

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()
1022
1023
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. ``None`` is treated like "none". Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. Defaults to None.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Always disabled when ``ssl_cert_reqs`` resolves to ``ssl.CERT_NONE``. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification). Requires the ``cryptography`` package. Defaults to False.
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response. Cannot be combined with ``ssl_validate_ocsp``. Defaults to False.
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError: if Python was built without SSL support, or if
                ``ssl_cert_reqs`` is an unrecognised string.
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless (and rejected by ssl) without
        # certificate verification, so force it off for CERT_NONE.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.

        The plain socket is closed on any failure (not only OSError /
        RedisError) so that its file descriptor is never leaked.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: on an invalid OCSP configuration.
            ConnectionError: if pure OCSP validation reports an invalid
                certificate.
        """
        context = ssl.create_default_context()
        # check_hostname must be assigned before verify_mode: ssl rejects
        # CERT_NONE while check_hostname is still True.
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)
        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._validate_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                self._validate_ocsp(sslsock)
        except BaseException:
            # wrap_socket() took over the descriptor of ``sock``, so closing
            # ``sock`` in _connect() would not release it; close the SSL
            # socket here instead.
            sslsock.close()
            raise
        return sslsock

    def _validate_ocsp_staple(self):
        """Check the server's stapled OCSP response on a separate probe connection.

        The probe socket is always closed, whether or not verification
        succeeds.
        """
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # The staple is requested over another, short-lived socket.
        probe_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, probe_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            probe_sock.close()

    def _validate_ocsp(self, sslsock):
        """Run pure (non-stapled) OCSP validation via ``OCSPVerifier``.

        Raises:
            ConnectionError: if the verifier reports the certificate invalid.
        """
        from .ocsp import OCSPVerifier

        verifier = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
        if not verifier.is_valid():
            raise ConnectionError("ocsp validation error")
1193
1194
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        # The base initializer runs first; ``socket_timeout`` is assigned
        # afterwards so this class's value is the one that sticks.
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the (name, value) pairs used when rendering ``repr()``."""
        extras = [("client_name", self.client_name)] if self.client_name else []
        return [("path", self.path), ("db", self.db), *extras]

    @staticmethod
    def _discard_socket(sock):
        """Best-effort shutdown + close of a socket that failed to connect."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        sock.close()

    def _connect(self):
        """Open, connect and return a Unix domain stream socket to ``self.path``.

        ``socket_connect_timeout`` applies while connecting; once connected the
        socket switches to ``socket_timeout``.
        """
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Close the half-open socket so no ResourceWarning is emitted.
            self._discard_socket(uds)
            raise
        uds.settimeout(self.socket_timeout)
        return uds

    def _host_error(self):
        """Identify the endpoint in error messages by its socket path."""
        return self.path
1228
1229
# Upper-cased spellings that ``to_bool`` treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a (typically query-string) value to a boolean.

    ``None`` and ``""`` yield ``None``. A string whose upper-cased form is in
    ``FALSE_STRINGS`` yields ``False``. Anything else is passed to ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_word = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_word else bool(value)


# Converters applied by parse_url() to known query-string options; options
# not listed here are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS = dict(
    db=int,
    socket_timeout=float,
    socket_connect_timeout=float,
    socket_keepalive=to_bool,
    retry_on_timeout=to_bool,
    # NOTE(review): ``list`` applied to a query-string value splits it into
    # single characters (list("ab") == ["a", "b"]); confirm whether a
    # comma-separated parser was intended here.
    retry_on_error=list,
    max_connections=int,
    health_check_interval=int,
    ssl_check_hostname=to_bool,
    timeout=float,
)
1253
1254
def parse_url(url):
    """Translate a Redis URL into connection keyword arguments.

    Supported schemes are ``redis://``, ``rediss://`` (adds
    ``connection_class=SSLConnection``) and ``unix://`` (adds ``path`` and
    ``connection_class=UnixDomainSocketConnection``). Known query-string
    options are converted via ``URL_QUERY_ARGUMENT_PARSERS``; unknown ones are
    kept as strings. For TCP schemes, a numeric URL path becomes ``db`` unless
    a ``db`` query option was given.

    Raises:
        ValueError: for an unsupported scheme or an unconvertible option.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    # NOTE(review): parse_qs() already percent-decodes values, so the extra
    # unquote() below decodes them a second time; kept for compatibility.
    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = parser(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # Remaining schemes: redis:// and rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A db given in the query string takes precedence over the URL path.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs
1310
1311
# TypeVar bound to ConnectionPool so that ``from_url`` (annotated
# ``cls: Type[_CP]`` -> ``_CP``) is typed as returning an instance of
# whichever pool class it is called on, including subclasses.
_CP = TypeVar("_CP", bound="ConnectionPool")
1313
1314
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win. The one exception is ``connection_class``: an
        explicit keyword argument overrides the class implied by the URL
        scheme.
        """
        url_options = parse_url(url)

        # An explicit connection_class keyword beats the scheme-derived one.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        # None (and 0) mean "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

            # The pool owns the cache; connections get it via
            # CacheProxyConnection in make_connection(), not as kwargs.
            connection_kwargs.pop("cache", None)
            connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Forget all connections and record the current process id."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read() and self.cache is None:
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new connection.

        Raises:
            MaxConnectionsError: if ``max_connections`` have already been
                created.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        if self.cache is not None:
            return CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )

        return self.connection_class(**self.connection_kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created in this pool's process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """
        Use ``retry`` for future connections and for every existing one.

        The pool lock is held while iterating, because get_connection() and
        release() mutate ``_in_use_connections`` under that lock; iterating
        the set concurrently could raise "Set changed size during iteration".
        """
        self.connection_kwargs.update({"retry": retry})
        with self._lock:
            for conn in chain(self._available_connections, self._in_use_connections):
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate pooled connections with ``token``.

        Idle connections are re-authenticated immediately with an ``AUTH``
        command; in-use connections only receive the token via
        ``set_re_auth_token``.
        """
        with self._lock:
            for conn in self._available_connections:
                # The lambdas are invoked immediately by call_with_retry, so
                # capturing the loop variable ``conn`` is safe here.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _mock(self, error: RedisError):
        """
        No-op error callback passed to ``Retry.call_with_retry``.

        This is deliberately a plain (non-async) function: the synchronous
        retry helper calls it directly, so an ``async def`` would only create
        a coroutine that is never awaited and emits a RuntimeWarning.

        :param error: the error that triggered the retry (ignored)
        :return: None
        """
        pass
1656
1657
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

    >>> from redis.client import Redis
    >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of its connections are in use, rather
    than raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # queue_class and timeout must exist before super().__init__(),
        # because it calls self.reset(), which builds the queue.
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        # Track it so disconnect() can reach every connection ever made.
        self._connections.append(connection)
        return connection

    def _put_placeholder(self):
        """
        Put a ``None`` slot back into the queue, ignoring a full queue.

        A full queue means reset() (e.g. after a fork) already refilled it, so
        there is no missing slot to give back.
        """
        try:
            self.pool.put_nowait(None)
        except Full:
            pass

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        Raises:
            ConnectionError: if no connection becomes available within
                ``self.timeout`` seconds.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you never want
            # this to be raised, create the pool with ``timeout=None`` to
            # block until a connection is available.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # Return the slot we took; otherwise a failed construction
                # would permanently shrink the pool by one.
                self._put_placeholder()
                raise

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed. After a reset() the queue may already be full, in
            # which case no placeholder is needed.
            connection.disconnect()
            self._put_placeholder()
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()