1import copy
2import os
3import socket
4import sys
5import threading
6import time
7import weakref
8from abc import abstractmethod
9from itertools import chain
10from queue import Empty, Full, LifoQueue
11from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
12from urllib.parse import parse_qs, unquote, urlparse
13
14from redis.cache import (
15 CacheEntry,
16 CacheEntryStatus,
17 CacheFactory,
18 CacheFactoryInterface,
19 CacheInterface,
20 CacheKey,
21)
22
23from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
24from .auth.token import TokenInterface
25from .backoff import NoBackoff
26from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
27from .event import AfterConnectionReleasedEvent, EventDispatcher
28from .exceptions import (
29 AuthenticationError,
30 AuthenticationWrongNumberOfArgsError,
31 ChildDeadlockedError,
32 ConnectionError,
33 DataError,
34 RedisError,
35 ResponseError,
36 TimeoutError,
37)
38from .retry import Retry
39from .utils import (
40 CRYPTOGRAPHY_AVAILABLE,
41 HIREDIS_AVAILABLE,
42 SSL_AVAILABLE,
43 compare_versions,
44 deprecated_args,
45 ensure_string,
46 format_error_message,
47 get_lib_version,
48 str_if_bytes,
49)
50
51if SSL_AVAILABLE:
52 import ssl
53else:
54 ssl = None
55
56if HIREDIS_AVAILABLE:
57 import hiredis
58
# Byte tokens used to build RESP (Redis serialization protocol) frames:
# "*<argc>\r\n" starts an array, "$<len>\r\n" prefixes each bulk string.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used by AbstractConnection when `protocol` cannot be
# converted with int() (e.g. None).
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied" (distinct from None).
SENTINEL = object()

# Use the hiredis-backed parser when the extension is importable,
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)
73
74
class HiredisRespSerializer:
    """Serialize commands into RESP using the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A command name containing literal sub-arguments (e.g. ``"CONFIG GET"``)
        is split on whitespace so each word is sent as its own argument; a
        ``str`` name is always encoded and split.

        Returns:
            A one-element list holding the packed command.

        Raises:
            DataError: if hiredis rejects an argument with ``TypeError``;
                the original traceback is preserved.
        """
        name = args[0]
        if isinstance(name, str):
            args = (*name.encode().split(), *args[1:])
        elif b" " in name:
            args = (*name.split(), *args[1:])

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]
91
92
class PythonRespSerializer:
    """Pure-Python RESP serializer that chunks large payloads."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Byte length above which data is emitted as its own output chunk
        # instead of being concatenated into the running buffer.
        self._buffer_cutoff = buffer_cutoff
        # Callable applied to every argument to obtain its wire bytes.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        The client may have put literal sub-arguments into the command name
        (e.g. ``'CONFIG GET'``); the server expects them as separate
        arguments, so the first argument is split manually. The split pieces
        are bytestrings so they are not re-encoded.

        Returns:
            A list of bytes-like chunks whose concatenation is the RESP frame.
            Large arguments and memoryviews are emitted as standalone chunks
            to avoid building one huge buffer.
        """
        name = args[0]
        if isinstance(name, str):
            args = (*name.encode().split(), *args[1:])
        elif b" " in name:
            args = (*name.split(), *args[1:])

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            must_split = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if not must_split:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))
                continue
            # Flush what we have (plus this argument's header), then hand the
            # payload over untouched; its trailing CRLF starts the next buffer.
            chunks.append(SYM_EMPTY.join((pending, header)))
            chunks.append(encoded)
            pending = SYM_CRLF

        chunks.append(pending)
        return chunks
142
143
class ConnectionInterface:
    """Operations shared by concrete connections and connection proxies.

    Methods are tagged with ``@abstractmethod`` for documentation purposes;
    since this class does not use ``ABCMeta`` the tags are not enforced when
    a subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs shown in the connection's repr."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run after every (re)connect."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously passed to ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Instantiate ``parser_class`` and use it for this connection."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Run the post-connect handshake (auth, HELLO, SELECT, ...)."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""

    @abstractmethod
    def check_health(self):
        """Verify the connection is alive (e.g. via PING) when due."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already RESP-encoded command."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Encode and send a command."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether response data is available within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the reply to a previously sent command."""

    @abstractmethod
    def pack_command(self, *args):
        """Encode a single command into RESP chunks."""

    @abstractmethod
    def pack_commands(self, commands):
        """Encode several commands (e.g. a pipeline) into RESP chunks."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server information returned by the HELLO handshake, if any."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used by the next ``re_auth`` call."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate with the stored token, if one is set."""
223
224
class AbstractConnection(ConnectionInterface):
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error=SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        The `retry_on_error` object passed in is never modified; when
        `retry_on_timeout` is set, a new list with TimeoutError appended is
        stored instead.

        `protocol` selects the RESP version (2 or 3). A value that `int()`
        rejects with TypeError (e.g. None) falls back to DEFAULT_RESP_VERSION.

        Raises:
            DataError: if `username`/`password` are combined with
                `credential_provider`.
            ConnectionError: if `protocol` is not an integer, or not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on. Build a new
            # list rather than appending in place: the caller's object may be
            # reused for many connections (and would grow by one TimeoutError
            # each time) or may be an immutable tuple.
            retry_on_error = [*retry_on_error, TimeoutError]
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # The range check must not live in a ``finally`` clause: when int()
        # raises ValueError, ``p`` is unbound there and an UnboundLocalError
        # would replace the intended ConnectionError.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs rendered by ``__repr__``."""
        pass

    def __del__(self):
        # Best-effort cleanup; errors during interpreter teardown are ignored.
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Return ``packer`` if given, else the hiredis or pure-Python serializer."""
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(self, check_health: bool = True):
        """Connect (with retries), run the handshake, then fire connect callbacks.

        Does nothing if a socket is already open. ``check_health`` is forwarded
        to the default handshake so its commands may skip the PING check.

        Raises:
            TimeoutError: if the socket connect timed out.
            ConnectionError: for other socket errors while connecting.
            RedisError: re-raised from the handshake after disconnecting.
        """
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Open and return a connected socket; implemented by subclasses."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a short description of the endpoint for error messages."""
        pass

    def _error_message(self, exception):
        """Format ``exception`` together with this connection's endpoint."""
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        """Run the connection handshake, including health checks."""
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser(parser)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser(parser)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # CLIENT SETINFO errors (e.g. servers that lack the subcommand) are
        # deliberately ignored.
        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def _switch_to_resp3_parser(self, previous_parser):
        """Replace a RESP2 parser with a RESP3 one before a HELLO handshake.

        The new parser inherits ``previous_parser.EXCEPTION_CLASSES`` and is
        attached to this connection; parsers that are not RESP2 are kept.
        """
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = previous_parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # Only shut down a socket created by this process; after a fork the
        # child must not tear down the parent's connection.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        ``push_request`` is only forwarded to the parser under RESP3. A
        ``ResponseError`` returned by the parser is raised rather than
        returned. On read failures the connection is disconnected unless
        ``disconnect_on_error`` is false.
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def get_protocol(self) -> Union[int, str]:
        """Return the validated RESP protocol version."""
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Reply to the HELLO handshake, or None if HELLO was not sent."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for the next ``re_auth`` call."""
        self._re_auth_token = token

    def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None
728
729
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        Args:
            host: Host name or address resolved with ``socket.getaddrinfo``.
            port: TCP port; coerced with ``int()``.
            socket_keepalive: Enable ``SO_KEEPALIVE`` on the socket.
            socket_keepalive_options: Mapping of ``IPPROTO_TCP`` option -> value,
                applied only when ``socket_keepalive`` is true.
            socket_type: Passed as the *family* argument of
                ``socket.getaddrinfo`` (0 accepts any address family).
            **kwargs: Forwarded to ``AbstractConnection``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the failure and try the next resolved address.
                err = exc
                self._close_failed_socket(sock)
            except BaseException:
                # Anything else (e.g. a TypeError from a malformed keepalive
                # option, or KeyboardInterrupt during connect) aborts the whole
                # attempt; close the socket instead of leaking it.
                self._close_failed_socket(sock)
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    @staticmethod
    def _close_failed_socket(sock):
        """Shut down and close a socket from a failed attempt; ``None`` is ignored."""
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
        except OSError:
            pass
        sock.close()

    def _host_error(self):
        return f"{self.host}:{self.port}"
802
803
class CacheProxyConnection(ConnectionInterface):
    """Connection wrapper that adds client-side caching on top of ``conn``.

    Cacheable read commands are answered from ``cache`` when possible;
    server invalidation pushes (``CLIENT TRACKING``) evict stale entries.
    """

    # Placeholder value stored while a cacheable command is in flight.
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching.

        Raises:
            ConnectionError: if the handshake did not report the server name and
                version, or the server is not Redis >= MIN_ALLOWED_VERSION.
        """
        self._conn.connect()

        # handshake_metadata is None when no HELLO was exchanged.
        metadata = self._conn.handshake_metadata or {}
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command unless its reply can be served from the cache.

        Cacheable commands must pass their Redis keys via ``keys=``.

        Raises:
            ValueError: if a cacheable command is sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys"))
        )

        with self._cache_lock:
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                # We have to trigger invalidation processing in case if
                # it was cached by another connection to avoid
                # queueing invalidations in stale connections.
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                # Processing those invalidations may have evicted the entry;
                # only skip sending if it is still cached, otherwise
                # read_response would wait for a reply that was never requested.
                if self._cache.get(self._current_command_cache_key):
                    return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the cached reply if available, else read and cache the reply."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            key = self._current_command_cache_key
            cached = None if key is None else self._cache.get(key)
            if cached is not None and cached.status != CacheEntryStatus.IN_PROGRESS:
                # Deep copy so callers cannot mutate the cached value.
                res = copy.deepcopy(cached.cache_value)
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                self._current_command_cache_key = None
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def _connect(self):
        return self._conn._connect()

    def _host_error(self):
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: enable CLIENT TRACKING and hook invalidation pushes."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Consume any pending push messages (invalidations) on this connection."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict invalidated keys; a ``None`` key list means flush everything."""
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

    def get_protocol(self):
        return self._conn.get_protocol()

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()
1016
1017
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. None is treated as ssl.CERT_NONE. Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. Always disabled when ssl_cert_reqs resolves to ssl.CERT_NONE.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is meaningless (and rejected by SSLContext)
        # when certificates are not verified at all.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if the OCSP options are inconsistent or cryptography
                is missing for full OCSP validation.
            ConnectionError: if full OCSP validation fails.
        """
        context = ssl.create_default_context()
        # check_hostname must be set before verify_mode: SSLContext refuses
        # CERT_NONE while check_hostname is still enabled.
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                verifier = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not verifier.is_valid():
                    raise ConnectionError("ocsp validation error")
        except BaseException:
            # The SSL socket now owns the connection; the caller only closes
            # the plain socket, so release the wrapped one here on failure.
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Validate the server's stapled OCSP response over a separate connection.

        The auxiliary socket is always closed, whether the handshake succeeds
        or fails.
        """
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket
        raw_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, raw_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            raw_sock.close()
1187
1188
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server.

    Args:
        path: Filesystem path of the server's Unix domain socket.
        socket_timeout: Timeout applied to the socket once it has connected.
        **kwargs: Remaining options, forwarded to ``AbstractConnection``.
    """

    def __init__(self, path="", socket_timeout=None, **kwargs):
        # The parent initializer runs first; the attributes below are assigned
        # afterwards, so they take precedence over anything it sets.
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the ``(name, value)`` pairs used to describe this connection."""
        optional = [("client_name", self.client_name)] if self.client_name else []
        return [("path", self.path), ("db", self.db)] + optional

    def _connect(self):
        """Create a Unix domain socket connection."""
        unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # The connect timeout only governs connect(); the regular socket
        # timeout is installed once the connection has succeeded.
        unix_sock.settimeout(self.socket_connect_timeout)
        try:
            unix_sock.connect(self.path)
        except OSError:
            # Shut down and close the socket before re-raising so a failed
            # connect does not leave it open (avoids ResourceWarning).
            try:
                unix_sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            unix_sock.close()
            raise
        else:
            unix_sock.settimeout(self.socket_timeout)
            return unix_sock

    def _host_error(self):
        """Return the socket path as this connection's host description."""
        return self.path
1222
1223
# Upper-cased spellings that ``to_bool`` interprets as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a (typically query-string) value to a boolean.

    ``None`` and ``""`` yield ``None`` (meaning "not specified"). Strings
    matching one of ``FALSE_STRINGS`` case-insensitively yield ``False``.
    Any other value falls back to Python truthiness.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
1233
1234
def _parse_retry_on_error(value):
    """Parse a comma-separated list of exception class names from a URL.

    For example ``"ConnectionError,TimeoutError"`` becomes
    ``[ConnectionError, TimeoutError]``, using the Redis exception classes
    known to this module. Blank entries are ignored.

    Raises:
        ValueError: if a name does not match a known exception class.
    """
    known = {
        cls.__name__: cls
        for cls in (
            AuthenticationError,
            AuthenticationWrongNumberOfArgsError,
            ChildDeadlockedError,
            ConnectionError,
            DataError,
            RedisError,
            ResponseError,
            TimeoutError,
        )
    }
    errors = []
    for raw_name in value.split(","):
        name = raw_name.strip()
        if not name:
            continue
        try:
            errors.append(known[name])
        except KeyError:
            raise ValueError(f"Unknown exception class {name!r}") from None
    return errors


# Converters applied to recognized query-string options by ``parse_url``;
# options without an entry here are passed through as strings.
# ``retry_on_error`` previously used ``list``, which split the string into
# single characters instead of producing exception classes.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": _parse_retry_on_error,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "timeout": float,
}
1247
1248
def parse_url(url):
    """Parse a Redis connection URL into connection keyword arguments.

    Supports the ``redis://``, ``rediss://`` and ``unix://`` schemes.
    Query-string options are converted with ``URL_QUERY_ARGUMENT_PARSERS``
    when a parser is registered for them, and kept as strings otherwise.

    Args:
        url: The connection URL.

    Returns:
        A dict of keyword arguments, such as ``host``, ``port``, ``db``,
        ``username``, ``password``, ``path`` and ``connection_class``.

    Raises:
        ValueError: if the scheme is unsupported or a query option cannot
            be converted.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    # parse_qs() already percent-decodes names and values and drops blank
    # values, so each value list is non-empty. Values must not be unquoted a
    # second time, or an encoded '%' (e.g. "%2541") would be decoded twice.
    for name, values in parse_qs(parsed.query).items():
        value = values[0]
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"Invalid value for '{name}' in connection URL."
                ) from err
        else:
            kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
1304
1305
# Type variable bound to ConnectionPool, used by ``ConnectionPool.from_url``
# so that calling it on a subclass is typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
1307
1308
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use class:`.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicit connection_class keyword overrides the one implied by
        # the URL scheme.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        # Cache options are consumed by the pool and must not reach the
        # connection class constructor.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read() and self.cache is None:
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """Create a new connection.

        Raises:
            ConnectionError: if ``max_connections`` connections were already
                created by this pool.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")

        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)

        # Count the connection only after it was actually constructed, so a
        # failing constructor does not permanently consume pool capacity.
        self._created_connections += 1
        return connection

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created in this pool's process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only
        disconnect connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Use ``retry`` for new connections and all existing ones."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with ``token``.

        Idle connections send ``AUTH`` immediately; connections currently in
        use only receive the token via ``set_re_auth_token``.
        """
        with self._lock:
            for conn in self._available_connections:
                # The lambdas are invoked synchronously within this iteration,
                # so capturing the loop variable ``conn`` is safe.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    self._mock,
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), self._mock
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _mock(self, error: RedisError):
        """
        No-op failure callback passed to ``Retry.call_with_retry``.

        It must be a plain function: ``re_auth_callback`` runs synchronously
        and never awaits the callback, so a coroutine function here would only
        produce a never-awaited coroutine (and a RuntimeWarning) on errors.

        :param error: the error reported by the retry loop (ignored).
        :return: None
        """
        pass
1649
1650
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # These must be set before the parent initializer, which calls
        # reset() and therefore needs queue_class.
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        Raises:
            ConnectionError: if no connection becomes available within
                ``self.timeout`` seconds, or the connection is not ready.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. To wait indefinitely
            # instead, construct the pool with ``timeout=None``.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # The ``None`` taken from the queue stands for a free slot;
                # give it back so a failed construction doesn't permanently
                # shrink the pool.
                self._return_placeholder()
                raise

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            self._return_placeholder()
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def _return_placeholder(self):
        """Put a ``None`` placeholder back into the queue, tolerating a full queue."""
        try:
            self.pool.put_nowait(None)
        except Full:
            # e.g. the pool was reset() after a fork and is already full of
            # placeholders; there is no slot to give back.
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()