# Coverage listing for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/connection.py
# coverage.py v7.2.2, created at 2023-03-26 07:09 +0000; 858 statements, 20% covered
import copy
import errno
import io
import os
import socket
import sys
import threading
import weakref
from abc import abstractmethod
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
from typing import Optional, Union
from urllib.parse import parse_qs, unquote, urlparse

from redis.backoff import NoBackoff
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ChildDeadlockedError,
    ConnectionError,
    DataError,
    ExecAbortError,
    InvalidResponse,
    ModuleError,
    NoPermissionError,
    NoScriptError,
    ReadOnlyError,
    RedisError,
    ResponseError,
    TimeoutError,
)
from redis.retry import Retry
from redis.utils import (
    CRYPTOGRAPHY_AVAILABLE,
    HIREDIS_AVAILABLE,
    HIREDIS_PACK_AVAILABLE,
    str_if_bytes,
)

try:
    import ssl

    ssl_available = True
except ImportError:
    ssl_available = False

NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}

if ssl_available:
    if hasattr(ssl, "SSLWantReadError"):
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2
    else:
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())

if HIREDIS_AVAILABLE:
    import hiredis

SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."

SENTINEL = object()
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module "
    "exports one or more module-side data "
    "types, can't unload"
)
# the user sent an AUTH command to a server without authorization configured
NO_AUTH_SET_ERROR = {
    # Redis >= 6.0
    "AUTH <password> called without any password "
    "configured for the default user. Are you sure "
    "your configuration is correct?": AuthenticationError,
    # Redis < 6.0
    "Client sent AUTH, but no password is set": AuthenticationError,
}


class Encoder:
    "Encode strings to bytes-like and decode bytes-like to strings"

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        "Return a bytestring or bytes-like representation of the value"
        if isinstance(value, (bytes, memoryview)):
            return value
        elif isinstance(value, bool):
            # special case bool since it is a subclass of int
            raise DataError(
                "Invalid input of type: 'bool'. Convert to a "
                "bytes, string, int or float first."
            )
        elif isinstance(value, (int, float)):
            value = repr(value).encode()
        elif not isinstance(value, str):
            # a value we don't know how to deal with. throw an error
            typename = type(value).__name__
            raise DataError(
                f"Invalid input of type: '{typename}'. "
                f"Convert to a bytes, string, int or float first."
            )
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def decode(self, value, force=False):
        "Return a unicode string from the bytes-like representation"
        if self.decode_responses or force:
            if isinstance(value, memoryview):
                value = value.tobytes()
            if isinstance(value, bytes):
                value = value.decode(self.encoding, self.encoding_errors)
        return value
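
# Illustrative sketch (not part of the module): how Encoder round-trips values.
# Only the constructor arguments defined above are assumed.
#
#   >>> enc = Encoder(encoding="utf-8", encoding_errors="strict", decode_responses=True)
#   >>> enc.encode("hello")
#   b'hello'
#   >>> enc.encode(42)
#   b'42'
#   >>> enc.decode(b"world")
#   'world'
#   >>> enc.encode(True)  # raises DataError: convert bools to int/str first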

class BaseParser:
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
    }

    def parse_error(self, response):
        "Parse an error response"
        error_code = response.split(" ")[0]
        if error_code in self.EXCEPTION_CLASSES:
            response = response[len(error_code) + 1 :]
            exception_class = self.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                exception_class = exception_class.get(response, ResponseError)
            return exception_class(response)
        return ResponseError(response)
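
# Illustrative sketch (not part of the module): parse_error() maps Redis error
# strings to exception instances; unknown codes fall back to ResponseError.
#
#   >>> parser = BaseParser()
#   >>> type(parser.parse_error("LOADING Redis is loading the dataset in memory"))
#   <class 'redis.exceptions.BusyLoadingError'>
#   >>> type(parser.parse_error("ERR unknown command 'FOO'"))
#   <class 'redis.exceptions.ResponseError'>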

class SocketBuffer:
    def __init__(
        self, socket: socket.socket, socket_read_size: int, socket_timeout: float
    ):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()

    def unread_bytes(self) -> int:
        """
        Remaining unread length of buffer
        """
        pos = self._buffer.tell()
        end = self._buffer.seek(0, SEEK_END)
        self._buffer.seek(pos)
        return end - pos

    def _read_from_socket(
        self,
        length: Optional[int] = None,
        timeout: Union[float, object] = SENTINEL,
        raise_on_timeout: Optional[bool] = True,
    ) -> bool:
        sock = self._sock
        socket_read_size = self.socket_read_size
        marker = 0
        custom_timeout = timeout is not SENTINEL

        buf = self._buffer
        current_pos = buf.tell()
        buf.seek(0, SEEK_END)
        if custom_timeout:
            sock.settimeout(timeout)
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                marker += data_length

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            buf.seek(current_pos)
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout: float) -> bool:
        return bool(self.unread_bytes()) or self._read_from_socket(
            timeout=timeout, raise_on_timeout=False
        )

    def read(self, length: int) -> bytes:
        length = length + 2  # make sure to read the \r\n terminator
        # BufferIO will return less than requested if buffer is short
        data = self._buffer.read(length)
        missing = length - len(data)
        if missing:
            # fill up the buffer and read the remainder
            self._read_from_socket(missing)
            data += self._buffer.read(missing)
        return data[:-2]

    def readline(self) -> bytes:
        buf = self._buffer
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            data += buf.readline()

        return data[:-2]

    def get_pos(self) -> int:
        """
        Get current read position
        """
        return self._buffer.tell()

    def rewind(self, pos: int) -> None:
        """
        Rewind the buffer to a specific position, to re-start reading
        """
        self._buffer.seek(pos)

    def purge(self) -> None:
        """
        After a successful read, purge the read part of buffer
        """
        unread = self.unread_bytes()

        # Only if we have read all of the buffer do we truncate, to
        # reduce the amount of memory thrashing. This heuristic
        # can be changed or removed later.
        if unread > 0:
            return

        if unread > 0:
            # move unread data to the front
            view = self._buffer.getbuffer()
            view[:unread] = view[-unread:]
        self._buffer.truncate(unread)
        self._buffer.seek(0)

    def close(self) -> None:
        try:
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None
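
# Illustrative sketch (not part of the module): feeding a SocketBuffer from an
# in-process socketpair. The names a, b, and buf are local to the example.
#
#   >>> import socket as _socket
#   >>> a, b = _socket.socketpair()
#   >>> a.sendall(b"+OK\r\n")
#   >>> buf = SocketBuffer(b, socket_read_size=65536, socket_timeout=None)
#   >>> buf.readline()
#   b'+OK'
#   >>> buf.unread_bytes()
#   0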

class PythonParser(BaseParser):
    "Plain Python parsing class"

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        self._buffer = SocketBuffer(
            self._sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        return self._buffer and self._buffer.can_read(timeout)

    def read_response(self, disable_decoding=False):
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(disable_decoding=disable_decoding)
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False):
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # int value
        elif byte == b":":
            return int(response)
        # bulk response
        elif byte == b"$" and response == b"-1":
            return None
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # multi-bulk response
        elif byte == b"*" and response == b"-1":
            return None
        elif byte == b"*":
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for i in range(int(response))
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response
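
# Illustrative sketch (not part of the module): the RESP2 framings handled by
# _read_response() above. The bytes are standard RESP; the Python values shown
# reflect the branches in the parser (with decoding disabled).
#
#   +OK\r\n                            -> b'OK' (simple string)
#   :1000\r\n                          -> 1000 (integer)
#   $3\r\nfoo\r\n                      -> b'foo' (bulk string)
#   $-1\r\n                            -> None (null bulk string)
#   *2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n   -> [b'foo', b'bar'] (array)
#   -ERR unknown command\r\n           -> ResponseError instance (returned, not raised)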

class HiredisParser(BaseParser):
    "Parser class for connections using Hiredis"

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        self._buffer = bytearray(socket_read_size)

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection, **kwargs):
        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
        }

        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**kwargs)
        self._next_response = False

    def on_disconnect(self):
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self, timeout):
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is False:
            self._next_response = self._reader.gets()
            if self._next_response is False:
                return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = self._sock.recv_into(self._buffer)
            if bufflen == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self, disable_decoding=False):
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not False:
            response = self._next_response
            self._next_response = False
            return response

        if disable_decoding:
            response = self._reader.gets(False)
        else:
            response = self._reader.gets()

        while response is False:
            self.read_from_socket()
            if disable_decoding:
                response = self._reader.gets(False)
            else:
                response = self._reader.gets()
        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif (
            isinstance(response, list)
            and response
            and isinstance(response[0], ConnectionError)
        ):
            raise response[0]
        return response


DefaultParser: BaseParser
if HIREDIS_AVAILABLE:
    DefaultParser = HiredisParser
else:
    DefaultParser = PythonParser

class HiredisRespSerializer:
    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        output = []

        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]
        try:
            output.append(hiredis.pack_command(args))
        except TypeError:
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return output


class PythonRespSerializer:
    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
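
# Illustrative sketch (not part of the module): what PythonRespSerializer.pack()
# produces for a small command, assuming an Encoder like the one defined above.
#
#   >>> enc = Encoder("utf-8", "strict", False)
#   >>> packer = PythonRespSerializer(buffer_cutoff=6000, encode=enc.encode)
#   >>> packer.pack("SET", "foo", "bar")
#   [b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n']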

class AbstractConnection:
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db=0,
        password=None,
        retry_on_timeout=False,
        retry_on_error=SENTINEL,
        encoding="utf-8",
        encoding_errors="strict",
        decode_responses=False,
        parser_class=DefaultParser,
        socket_read_size=65536,
        health_check_interval=0,
        client_name=None,
        username=None,
        retry=None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        command_packer=None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_error.append(TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"{self.__class__.__name__}<{repr_args}>"

    @abstractmethod
    def repr_pieces(self):
        pass

    def __del__(self):
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        if packer is not None:
            return packer
        elif HIREDIS_PACK_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        self._connect_callbacks.append(weakref.WeakMethod(callback))

    def clear_connect_callbacks(self):
        self._connect_callbacks = []

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect()
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        pass

    @abstractmethod
    def _host_error(self):
        pass

    @abstractmethod
    def _error_message(self, exception):
        pass

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command("CLIENT", "SETNAME", self.client_name)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return

        if os.getpid() == self.pid:
            try:
                self._sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass

        try:
            self._sock.close()
        except OSError:
            pass
        self._sock = None

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except Exception:
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(self, disable_decoding=False):
        """Read the response from a previously sent command"""

        host_error = self._host_error()

        try:
            response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            self.disconnect()
            raise ConnectionError(
                f"Error while reading from {host_error}" f" : {e.args}"
            )
        except Exception:
            self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_timeout=None,
        socket_connect_timeout=None,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self.host = host
        self.port = int(port)
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                err = _
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message")
        # or just "message"

        host_error = self._host_error()

        if len(exception.args) == 1:
            try:
                return f"Error connecting to {host_error}. \
                        {exception.args[0]}."
            except AttributeError:
                return f"Connection Error: {exception.args[0]}"
        else:
            try:
                return (
                    f"Error {exception.args[0]} connecting to "
                    f"{host_error}. {exception.args[1]}."
                )
            except AttributeError:
                return f"Connection Error: {exception.args[0]}"
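
# Illustrative sketch (not part of the module): driving a Connection directly,
# assuming a Redis server is listening on localhost:6379.
#
#   >>> conn = Connection(host="localhost", port=6379)
#   >>> conn.send_command("PING")
#   >>> conn.read_response()
#   b'PONG'
#   >>> conn.disconnect()
#
# In normal use a ConnectionPool (defined below) creates and recycles these objects.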

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=False,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full OCSP validation (i.e. not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled OCSP response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the OCSP verification service.

        Raises:
            RedisError
        """  # noqa
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = ssl_check_hostname
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        super().__init__(**kwargs)

    def _connect(self):
        "Wrap the socket with SSL support"
        sock = super()._connect()
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        sslsock = context.wrap_socket(sock, server_hostname=self.host)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock
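
# Illustrative sketch (not part of the module): the usual route to this class is
# a rediss:// URL passed to ConnectionPool.from_url (defined below). Host,
# password, and certificate paths are placeholders.
#
#   >>> pool = ConnectionPool.from_url(
#   ...     "rediss://:secret@redis.example.com:6380/0",
#   ...     ssl_cert_reqs="required",
#   ...     ssl_ca_certs="/path/to/ca.pem",
#   ... )
#   >>> pool.connection_class is SSLConnection
#   True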

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", **kwargs):
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_timeout)
        sock.connect(self.path)
        return sock

    def _host_error(self):
        return self.path

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message")
        # or just "message"
        host_error = self._host_error()
        if len(exception.args) == 1:
            return (
                f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
            )
        else:
            return (
                f"Error {exception.args[0]} connecting to unix socket: "
                f"{host_error}. {exception.args[1]}."
            )

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    if value is None or value == "":
        return None
    if isinstance(value, str) and value.upper() in FALSE_STRINGS:
        return False
    return bool(value)


URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
}


def parse_url(url):
    if not (
        url.startswith("redis://")
        or url.startswith("rediss://")
        or url.startswith("unix://")
    ):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        if value and len(value) > 0:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for `{name}` in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
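
# Illustrative sketch (not part of the module): what parse_url() returns for a
# typical URL; the host, password, and db values are placeholders.
#
#   >>> parse_url("redis://:secret@redis.example.com:6380/2?socket_timeout=5")
#   {'socket_timeout': 5.0, 'password': 'secret', 'host': 'redis.example.com',
#    'port': 6380, 'db': 2}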

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :class:`.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls, url, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self, connection_class=Connection, max_connections=None, **connection_kwargs
    ):
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"<{repr(self.connection_class(**self.connection_kwargs))}>"
        )

    def reset(self):
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self):
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    def get_connection(self, command_name, *keys, **options):
        "Get a connection from the pool"
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self):
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection):
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                pass

            if self.owns_connection(connection):
                self._available_connections.append(connection)
            else:
                # pool doesn't own this connection. do not add it back
                # to the pool and decrement the count so that another
                # connection can take its place if needed
                self._created_connections -= 1
                connection.disconnect()
                return

    def owns_connection(self, connection):
        return connection.pid == self.pid

    def disconnect(self, inuse_connections=True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry
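
# Illustrative sketch (not part of the module): the checkout/release cycle,
# assuming a Redis server on localhost:6379.
#
#   >>> pool = ConnectionPool(host="localhost", port=6379, max_connections=10)
#   >>> conn = pool.get_connection("PING")
#   >>> conn.send_command("PING")
#   >>> conn.read_response()
#   b'PONG'
#   >>> pool.release(conn)
#   >>> pool.disconnect()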

class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):

        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # it's needed.
            connection.disconnect()
            self.pool.put_nowait(None)
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()
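
# Illustrative sketch (not part of the module): with max_connections=1, a second
# checkout blocks for up to ``timeout`` seconds and then raises ConnectionError.
# Assumes a reachable Redis server on the default localhost port.
#
#   >>> pool = BlockingConnectionPool(max_connections=1, timeout=2)
#   >>> first = pool.get_connection("PING")
#   >>> pool.get_connection("PING")  # blocks ~2s, then raises
#   Traceback (most recent call last):
#       ...
#   redis.exceptions.ConnectionError: No connection available.
#   >>> pool.release(first)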