Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/connection.py: 20%
864 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 07:16 +0000
1import copy
2import errno
3import io
4import os
5import socket
6import sys
7import threading
8import weakref
9from abc import abstractmethod
10from io import SEEK_END
11from itertools import chain
12from queue import Empty, Full, LifoQueue
13from time import time
14from typing import Optional, Union
15from urllib.parse import parse_qs, unquote, urlparse
17from redis.backoff import NoBackoff
18from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
19from redis.exceptions import (
20 AuthenticationError,
21 AuthenticationWrongNumberOfArgsError,
22 BusyLoadingError,
23 ChildDeadlockedError,
24 ConnectionError,
25 DataError,
26 ExecAbortError,
27 InvalidResponse,
28 ModuleError,
29 NoPermissionError,
30 NoScriptError,
31 ReadOnlyError,
32 RedisError,
33 ResponseError,
34 TimeoutError,
35)
36from redis.retry import Retry
37from redis.utils import (
38 CRYPTOGRAPHY_AVAILABLE,
39 HIREDIS_AVAILABLE,
40 HIREDIS_PACK_AVAILABLE,
41 str_if_bytes,
42)
try:
    import ssl

    ssl_available = True
except ImportError:
    ssl_available = False

# Exception types raised by non-blocking sockets when no data is available,
# mapped to the errno value treated as "would block" for that type.
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}

if ssl_available:
    if hasattr(ssl, "SSLWantReadError"):
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2
    else:
        # older ssl modules without the specific Want* errors
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())

if HIREDIS_AVAILABLE:
    import hiredis

# RESP wire-protocol symbols
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."

# sentinel for "argument not supplied" where None is a meaningful value
SENTINEL = object()
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module "
    "exports one or more module-side data "
    "types, can't unload"
)
# user send an AUTH cmd to a server without authorization configured
NO_AUTH_SET_ERROR = {
    # Redis >= 6.0
    "AUTH <password> called without any password "
    "configured for the default user. Are you sure "
    "your configuration is correct?": AuthenticationError,
    # Redis < 6.0
    "Client sent AUTH, but no password is set": AuthenticationError,
}
class Encoder:
    """Encode values to bytes for the wire; decode replies back to str."""

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        "Return a bytestring or bytes-like representation of the value"
        if isinstance(value, (bytes, memoryview)):
            # already wire-ready
            return value
        if isinstance(value, bool):
            # special case bool since it is a subclass of int
            raise DataError(
                "Invalid input of type: 'bool'. Convert to a "
                "bytes, string, int or float first."
            )
        if isinstance(value, (int, float)):
            return repr(value).encode()
        if isinstance(value, str):
            return value.encode(self.encoding, self.encoding_errors)
        # a value we don't know how to deal with. throw an error
        typename = type(value).__name__
        raise DataError(
            f"Invalid input of type: '{typename}'. "
            f"Convert to a bytes, string, int or float first."
        )

    def decode(self, value, force=False):
        "Return a unicode string from the bytes-like representation"
        if not (self.decode_responses or force):
            return value
        if isinstance(value, memoryview):
            value = value.tobytes()
        if isinstance(value, bytes):
            value = value.decode(self.encoding, self.encoding_errors)
        return value
class BaseParser:
    """Shared behavior for RESP parsers: maps error replies to exceptions."""

    # maps the leading error code of a "-" reply to an exception class;
    # "ERR" is further sub-classified by the full message text.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
    }

    @classmethod
    def parse_error(cls, response):
        "Parse an error response"
        error_code = response.split(" ")[0]
        exception_class = cls.EXCEPTION_CLASSES.get(error_code)
        if exception_class is None:
            # unknown code: wrap the full text in a generic ResponseError
            return ResponseError(response)
        # strip the "<CODE> " prefix, keep only the message
        response = response[len(error_code) + 1 :]
        if isinstance(exception_class, dict):
            exception_class = exception_class.get(response, ResponseError)
        return exception_class(response)
class SocketBuffer:
    """
    Buffers data received from a socket in a BytesIO object so callers can
    read exact lengths or CRLF-terminated lines without a recv() per call.
    """

    def __init__(
        self, socket: socket.socket, socket_read_size: int, socket_timeout: float
    ):
        self._sock = socket
        # number of bytes requested per recv() call
        self.socket_read_size = socket_read_size
        # default timeout, restored after reads issued with a custom timeout
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()

    def unread_bytes(self) -> int:
        """
        Remaining unread length of buffer
        """
        pos = self._buffer.tell()
        end = self._buffer.seek(0, SEEK_END)
        self._buffer.seek(pos)
        return end - pos

    def _read_from_socket(
        self,
        length: Optional[int] = None,
        timeout: Union[float, object] = SENTINEL,
        raise_on_timeout: Optional[bool] = True,
    ) -> bool:
        """
        Append data from the socket to the internal buffer.

        Loops until at least ``length`` bytes have been received (a single
        recv() when ``length`` is None).  Returns True when data was read;
        returns False on timeout / non-blocking miss when
        ``raise_on_timeout`` is falsy, otherwise raises.
        """
        sock = self._sock
        socket_read_size = self.socket_read_size
        marker = 0
        custom_timeout = timeout is not SENTINEL

        buf = self._buffer
        current_pos = buf.tell()
        # append at the end; the read position is restored in the finally
        buf.seek(0, SEEK_END)
        if custom_timeout:
            sock.settimeout(timeout)
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                marker += data_length

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            buf.seek(current_pos)
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout: float) -> bool:
        """Return True if unread buffered data exists or the socket yields data."""
        return bool(self.unread_bytes()) or self._read_from_socket(
            timeout=timeout, raise_on_timeout=False
        )

    def read(self, length: int) -> bytes:
        """Read exactly ``length`` bytes (plus the trailing CRLF, stripped)."""
        length = length + 2  # make sure to read the \r\n terminator
        # BufferIO will return less than requested if buffer is short
        data = self._buffer.read(length)
        missing = length - len(data)
        if missing:
            # fill up the buffer and read the remainder
            self._read_from_socket(missing)
            data += self._buffer.read(missing)
        return data[:-2]

    def readline(self) -> bytes:
        """Read one CRLF-terminated line, returned without the terminator."""
        buf = self._buffer
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            data += buf.readline()

        return data[:-2]

    def get_pos(self) -> int:
        """
        Get current read position
        """
        return self._buffer.tell()

    def rewind(self, pos: int) -> None:
        """
        Rewind the buffer to a specific position, to re-start reading
        """
        self._buffer.seek(pos)

    def purge(self) -> None:
        """
        After a successful read, purge the read part of buffer.

        Only if we have read all of the buffer do we truncate, to
        reduce the amount of memory thrashing. This heuristic
        can be changed or removed later.
        """
        if self.unread_bytes() > 0:
            return
        # Fully consumed: reset to an empty buffer.  (A second
        # `if unread > 0:` compaction branch previously followed the early
        # return above; it was unreachable and has been removed.)
        self._buffer.truncate(0)
        self._buffer.seek(0)

    def close(self) -> None:
        """Close the underlying BytesIO and drop the socket reference."""
        try:
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    "Plain Python parsing class"

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        # all three are populated by on_connect() and cleared on disconnect
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        # best-effort cleanup; never raise from a finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        self._buffer = SocketBuffer(
            self._sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return truthy if a reply (or partial data) is available to read."""
        return self._buffer and self._buffer.can_read(timeout)

    def read_response(self, disable_decoding=False):
        """
        Read one full reply from the buffer.

        On any failure the buffer is rewound to the pre-read position so a
        retry can re-parse the same bytes; on success the consumed part of
        the buffer is purged.
        """
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(disable_decoding=disable_decoding)
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False):
        """Recursively parse one RESP reply based on its leading type byte."""
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # first byte selects the RESP type; the rest is the payload/length
        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # int value
        elif byte == b":":
            return int(response)
        # bulk response
        elif byte == b"$" and response == b"-1":
            # RESP null bulk string
            return None
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # multi-bulk response
        elif byte == b"*" and response == b"-1":
            # RESP null array
            return None
        elif byte == b"*":
            # recurse once per array element
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for i in range(int(response))
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response
class HiredisParser(BaseParser):
    "Parser class for connections using Hiredis"

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # reusable recv_into() buffer to avoid per-read allocations
        self._buffer = bytearray(socket_read_size)

    def __del__(self):
        # best-effort cleanup; never raise from a finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection, **kwargs):
        """Bind to a freshly-connected connection and build a hiredis Reader."""
        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
        }

        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**kwargs)
        # False doubles as "no cached reply"; see can_read()/read_response()
        self._next_response = False

    def on_disconnect(self):
        """Drop the socket, reader, and any cached reply."""
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self, timeout):
        """Return True if a complete reply is buffered or data is readable."""
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is False:
            # cache the reply so the next read_response() returns it
            self._next_response = self._reader.gets()
            if self._next_response is False:
                return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Feed one recv() worth of bytes into the hiredis reader."""
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = self._sock.recv_into(self._buffer)
            if bufflen == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self, disable_decoding=False):
        """Return the next complete reply, reading from the socket as needed."""
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not False:
            response = self._next_response
            self._next_response = False
            return response

        if disable_decoding:
            response = self._reader.gets(False)
        else:
            response = self._reader.gets()
        # hiredis returns False while the reply is incomplete; keep feeding
        while response is False:
            self.read_from_socket()
            if disable_decoding:
                response = self._reader.gets(False)
            else:
                response = self._reader.gets()
        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif (
            isinstance(response, list)
            and response
            and isinstance(response[0], ConnectionError)
        ):
            raise response[0]
        return response
DefaultParser: BaseParser
# Prefer the C-accelerated hiredis parser when the extension is installed;
# fall back to the pure-Python implementation otherwise.
if HIREDIS_AVAILABLE:
    DefaultParser = HiredisParser
else:
    DefaultParser = PythonParser
class HiredisRespSerializer:
    """Packs commands into RESP using hiredis' C implementation."""

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        command = args[0]
        # multi-word command names (e.g. "CONFIG GET") must be split into
        # separate arguments before packing
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError as err:
            # surface unpackable values as the library's DataError,
            # preserving the original traceback
            raise DataError(err).with_traceback(err.__traceback__)

        return [packed]
class PythonRespSerializer:
    """Packs commands into RESP wire chunks in pure Python."""

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        cutoff = self._buffer_cutoff
        chunks = []
        # running chunk, seeded with the "*<argc>\r\n" array header
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        for encoded in map(self.encode, args):
            size = len(encoded)
            header = SYM_EMPTY.join((SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            if len(pending) > cutoff or size > cutoff or isinstance(encoded, memoryview):
                # to avoid large string mallocs, emit the accumulated chunk
                # and the large value / memoryview as separate output items
                chunks.append(pending + header)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, header, encoded, SYM_CRLF))
        chunks.append(pending)
        return chunks
class AbstractConnection:
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db=0,
        password=None,
        retry_on_timeout=False,
        retry_on_error=SENTINEL,
        encoding="utf-8",
        encoding_errors="strict",
        decode_responses=False,
        parser_class=DefaultParser,
        socket_read_size=65536,
        health_check_interval=0,
        client_name=None,
        username=None,
        retry=None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        command_packer=None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        # credential_provider is mutually exclusive with username/password
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        # pid is recorded so disconnect() can tell whether this process
        # owns the socket (see disconnect)
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_error.append(TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            # no retry configured: a zero-attempt Retry means "try once"
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        """Return e.g. ``Connection<host=...,port=...,db=...>``."""
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"{self.__class__.__name__}<{repr_args}>"

    @abstractmethod
    def repr_pieces(self):
        """Return (key, value) pairs describing this connection for repr()."""
        pass

    def __del__(self):
        # best-effort disconnect; never raise from a finalizer
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Pick the command packer: explicit > hiredis > pure Python."""
        if packer is not None:
            return packer
        elif HIREDIS_PACK_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """Register a bound method to be invoked after each (re)connect."""
        # WeakMethod so registered callbacks don't keep their owner alive
        self._connect_callbacks.append(weakref.WeakMethod(callback))

    def clear_connect_callbacks(self):
        """Remove all registered connect callbacks."""
        self._connect_callbacks = []

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect()
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for ref in self._connect_callbacks:
            callback = ref()
            # ref() is None when the weakly-referenced method's owner died
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return a connected socket (transport-specific)."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a short host description for error messages."""
        pass

    @abstractmethod
    def _error_message(self, exception):
        """Build a human-readable message for a connection failure."""
        pass

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command("CLIENT", "SETNAME", self.client_name)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # only shut the socket down from the process that created it;
        # a forked child must not tear down the parent's connection
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # args for socket errors may be (errno, message) or just (message,)
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error: bool = True
    ):
        """Read the response from a previously sent command"""

        host_error = self._host_error()

        try:
            response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(
                f"Error while reading from {host_error}" f" : {e.args}"
            )
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        # a successful read postpones the next health check
        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush the accumulated small pieces before a large chunk
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    # large chunks / memoryviews are emitted on their own
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_timeout=None,
        socket_connect_timeout=None,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        Create a TCP connection description (no I/O happens here).

        Args:
            host: Redis server hostname or IP. Defaults to "localhost".
            port: Redis server port. Defaults to 6379.
            socket_timeout: timeout for socket reads/writes, in seconds.
            socket_connect_timeout: timeout for the initial connect; falls
                back to ``socket_timeout`` when not supplied.
            socket_keepalive: if True, enable SO_KEEPALIVE on the socket.
            socket_keepalive_options: mapping of TCP-level keepalive options
                applied when ``socket_keepalive`` is set.
            socket_type: address-family filter passed to getaddrinfo
                (0 = any family).
        """
        self.host = host
        self.port = int(port)
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (key, value) pairs used to build this connection's repr."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # remember the last failure; try the next resolved address
                err = exc
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Return "host:port" for use in error messages."""
        return f"{self.host}:{self.port}"

    def _error_message(self, exception):
        """
        Build a human-readable message for a connection failure.

        args for socket.error can either be (errno, "message") or just
        "message".
        """
        host_error = self._host_error()

        if len(exception.args) == 1:
            # NOTE: the previous implementation used a backslash
            # line-continuation inside the f-string here, which embedded a
            # run of indentation spaces into the emitted message; it also
            # wrapped these returns in dead `except AttributeError`
            # handlers (`.args` always exists on an exception).
            return f"Error connecting to {host_error}. {exception.args[0]}."
        return (
            f"Error {exception.args[0]} connecting to "
            f"{host_error}. {exception.args[1]}."
        )
1030class SSLConnection(Connection):
1031 """Manages SSL connections to and from the Redis server(s).
1032 This class extends the Connection class, adding SSL functionality, and making
1033 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
1034 """ # noqa
    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=False,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.

        Raises:
            RedisError
        """  # noqa
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            # map the user-facing string values onto ssl module constants
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = ssl_check_hostname
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        super().__init__(**kwargs)
1102 def _connect(self):
1103 "Wrap the socket with SSL support"
1104 sock = super()._connect()
1105 context = ssl.create_default_context()
1106 context.check_hostname = self.check_hostname
1107 context.verify_mode = self.cert_reqs
1108 if self.certfile or self.keyfile:
1109 context.load_cert_chain(
1110 certfile=self.certfile,
1111 keyfile=self.keyfile,
1112 password=self.certificate_password,
1113 )
1114 if (
1115 self.ca_certs is not None
1116 or self.ca_path is not None
1117 or self.ca_data is not None
1118 ):
1119 context.load_verify_locations(
1120 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
1121 )
1122 sslsock = context.wrap_socket(sock, server_hostname=self.host)
1123 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
1124 raise RedisError("cryptography is not installed.")
1126 if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
1127 raise RedisError(
1128 "Either an OCSP staple or pure OCSP connection must be validated "
1129 "- not both."
1130 )
1132 # validation for the stapled case
1133 if self.ssl_validate_ocsp_stapled:
1134 import OpenSSL
1136 from .ocsp import ocsp_staple_verifier
1138 # if a context is provided use it - otherwise, a basic context
1139 if self.ssl_ocsp_context is None:
1140 staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
1141 staple_ctx.use_certificate_file(self.certfile)
1142 staple_ctx.use_privatekey_file(self.keyfile)
1143 else:
1144 staple_ctx = self.ssl_ocsp_context
1146 staple_ctx.set_ocsp_client_callback(
1147 ocsp_staple_verifier, self.ssl_ocsp_expected_cert
1148 )
1150 # need another socket
1151 con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
1152 con.request_ocsp()
1153 con.connect((self.host, self.port))
1154 con.do_handshake()
1155 con.shutdown()
1156 return sslsock
1158 # pure ocsp validation
1159 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
1160 from .ocsp import OCSPVerifier
1162 o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
1163 if o.is_valid():
1164 return sslsock
1165 else:
1166 raise ConnectionError("ocsp validation error")
1167 return sslsock
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        self.path = path
        self.socket_timeout = socket_timeout
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces += [("client_name", self.client_name)]
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # apply the timeout before connect() so it bounds the connect itself
        uds_sock.settimeout(self.socket_timeout)
        uds_sock.connect(self.path)
        return uds_sock

    def _host_error(self):
        """Identify this connection's target (the socket path) for error text."""
        return self.path

    def _error_message(self, exception):
        # socket.error args are either (errno, "message") or just ("message",)
        host_error = self._host_error()
        if len(exception.args) != 1:
            return (
                f"Error {exception.args[0]} connecting to unix socket: "
                f"{host_error}. {exception.args[1]}."
            )
        return (
            f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
        )
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a URL querystring value to ``bool``.

    ``None`` or the empty string yields ``None`` (option unset); case-insensitive
    "falsey" strings yield ``False``; anything else is passed through ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
# Maps querystring option names to the callable that casts the raw string
# value to its proper Python type; consulted by parse_url(). Options not
# listed here are passed through as plain (unquoted) strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
}
def parse_url(url):
    """Parse a Redis URL into a dict of keyword arguments.

    Supports the ``redis://``, ``rediss://`` and ``unix://`` schemes.
    Querystring options are cast via URL_QUERY_ARGUMENT_PARSERS; the
    username, password, hostname and path are percent-decoded. For the
    TCP schemes, a path component is treated as the db number when no
    ``db`` querystring option was given.

    Raises:
        ValueError: if the scheme is unsupported or a querystring value
            cannot be cast to its expected type.
    """
    # str.startswith accepts a tuple of prefixes - one call, not an or-chain
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # parse_qs only yields non-empty lists, but guard anyway;
        # the redundant `len(value) > 0` check was dropped.
        if value:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for `{name}` in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use class:`.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls, url, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # an explicit connection_class keyword overrides the one implied
        # by the URL scheme (e.g. SSLConnection for rediss://)
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self, connection_class=Connection, max_connections=None, **connection_kwargs
    ):
        # None means "effectively unlimited" (2**31 connections)
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self):
        # builds a throw-away connection purely to render its repr
        return (
            f"{type(self).__name__}"
            f"<{repr(self.connection_class(**self.connection_kwargs))}>"
        )

    def reset(self):
        """Reinitialize the pool's state (also used for fork-safety)."""
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self):
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    def get_connection(self, command_name, *keys, **options):
        "Get a connection from the pool"
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self):
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        "Create a new connection"
        # NOTE: callers hold self._lock, so the counter check/increment
        # here is not itself synchronized
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection):
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                pass

            if self.owns_connection(connection):
                self._available_connections.append(connection)
            else:
                # pool doesn't own this connection. do not add it back
                # to the pool and decrement the count so that another
                # connection can take its place if needed
                self._created_connections -= 1
                connection.disconnect()
                return

    def owns_connection(self, connection):
        # a connection belongs to this pool only if it was created in this
        # process (i.e. after the most recent fork/reset)
        return connection.pid == self.pid

    def disconnect(self, inuse_connections=True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        # update the kwargs used for future connections AND patch the retry
        # policy onto every connection that already exists
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # these must be set before super().__init__() runs, because the base
        # constructor calls reset(), which reads self.queue_class
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note: this ConnectionError is not caught by the redis client;
            # it propagates to application code unless handled there.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            self.pool.put_nowait(None)
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()