Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/asyncio/connection.py: 22%
665 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 06:16 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 06:16 +0000
1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import ssl
7import sys
8import warnings
9import weakref
10from abc import abstractmethod
11from itertools import chain
12from types import MappingProxyType
13from typing import (
14 Any,
15 Callable,
16 Iterable,
17 List,
18 Mapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27)
28from urllib.parse import ParseResult, parse_qs, unquote, urlparse
30# the functionality is available in 3.11.x but has a major issue before
31# 3.11.3. See https://github.com/redis/redis-py/issues/2633
32if sys.version_info >= (3, 11, 3):
33 from asyncio import timeout as async_timeout
34else:
35 from async_timeout import timeout as async_timeout
37from redis.asyncio.retry import Retry
38from redis.backoff import NoBackoff
39from redis.connection import DEFAULT_RESP_VERSION
40from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
41from redis.exceptions import (
42 AuthenticationError,
43 AuthenticationWrongNumberOfArgsError,
44 ConnectionError,
45 DataError,
46 RedisError,
47 ResponseError,
48 TimeoutError,
49)
50from redis.typing import EncodableT, KeysT, ResponseT
51from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
53from .._cache import (
54 DEFAULT_BLACKLIST,
55 DEFAULT_EVICTION_POLICY,
56 DEFAULT_WHITELIST,
57 AbstractCache,
58 _LocalCache,
59)
60from .._parsers import (
61 BaseParser,
62 Encoder,
63 _AsyncHiredisParser,
64 _AsyncRESP2Parser,
65 _AsyncRESP3Parser,
66)
# Byte tokens used by pack_command / pack_commands to frame commands in the
# Redis wire protocol.
SYM_STAR = b"*"  # prefix of the argument-count header ("*<n>\r\n")
SYM_DOLLAR = b"$"  # prefix of each argument's length header ("$<len>\r\n")
SYM_CRLF = b"\r\n"  # line terminator after headers and argument payloads
SYM_LF = b"\n"
SYM_EMPTY = b""  # joiner used to concatenate byte chunks
class _Sentinel(enum.Enum):
    """Single-member enum whose member marks "argument not supplied"."""

    sentinel = object()


# Default value for ``retry_on_error``; AbstractConnection compares against it
# with ``is`` to tell "not supplied" apart from an explicitly passed value.
SENTINEL = _Sentinel.sentinel
# Parser class used when a connection is not given an explicit parser_class:
# the hiredis-backed parser if the extension is importable, else pure-Python RESP2.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
)
class ConnectCallbackProtocol(Protocol):
    """Plain callable that receives the connection once it is established."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Coroutine function that receives the connection once it is established."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Type of ``redis_connect_func``: AbstractConnection.connect awaits it when it
# is a coroutine function and calls it directly otherwise.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
class AbstractConnection:
    """Manages communication to and from a Redis server.

    Subclasses supply the transport by implementing ``_connect``,
    ``_host_error``, ``_error_message`` and ``repr_pieces``. This base class
    performs the handshake (HELLO/AUTH, CLIENT SETNAME/SETINFO, SELECT,
    CLIENT TRACKING), packs commands, reads responses with timeouts, runs
    PING health checks and drives the optional client-side cache.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "client_cache",
        "cache_blacklist",
        "cache_whitelist",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        cache_enabled: bool = False,
        client_cache: Optional[AbstractCache] = None,
        cache_max_size: int = 10000,
        cache_ttl: int = 0,
        cache_policy: str = DEFAULT_EVICTION_POLICY,
        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
        cache_whitelist: List[str] = DEFAULT_WHITELIST,
    ):
        """Store connection settings; the socket is opened later by ``connect``.

        :param socket_connect_timeout: defaults to ``socket_timeout`` when None.
        :param retry_on_error: exception types to retry on. The given
            iterable is copied, so the caller's list is never mutated.
        :param retry_on_timeout: also retry on timeout errors.
        :param protocol: RESP version, 2 or 3 (``None`` is validated against
            the library default version).
        :param cache_enabled / client_cache: enable client-side caching,
            which requires protocol 3.
        :raises DataError: if username/password are combined with
            ``credential_provider``.
        :raises ConnectionError: if ``protocol`` is not an integer 2 or 3.
        :raises RedisError: if client caching is requested without RESP3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL or retry_on_error is None:
            retry_on_error = []
        else:
            # Copy: the appends below must not mutate a list owned (and
            # possibly shared between connections) by the caller.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # The range check must not live in a `finally`: after a ValueError
        # `p` is unbound there and an UnboundLocalError would mask the
        # ConnectionError raised above.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = protocol
        if cache_enabled:
            _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
        else:
            _cache = None
        self.client_cache = client_cache if client_cache is not None else _cache
        if self.client_cache is not None:
            if self.protocol not in [3, "3"]:
                raise RedisError(
                    "client caching is only supported with protocol version 3 or higher"
                )
        self.cache_blacklist = cache_blacklist
        self.cache_whitelist = cache_whitelist

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )
            self._close()

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs shown by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True while both the stream reader and writer are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected.

        Opens the transport through ``self.retry``, runs the handshake
        (``on_connect`` or the user's ``redis_connect_func``) and then every
        registered connect callback.

        :raises TimeoutError: if opening the transport timed out.
        :raises ConnectionError: for any other failure to open the transport.
        """
        if self.is_connected:
            return
        try:
            await self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect()
            )
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect()
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback is None:
                # The owner was collected while an earlier callback was awaited.
                continue
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader`` / ``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a short description of the endpoint for error messages."""
        pass

    @abstractmethod
    def _error_message(self, exception: BaseException) -> str:
        """Format a connection error message for ``exception``."""
        pass

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database.

        CLIENT SETINFO, SELECT and CLIENT TRACKING are pipelined; their replies
        are read back in exactly the order the commands were sent.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()
        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol)
            # The HELLO reply is consumed; its "proto" field is not validated
            # on this path.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command("CLIENT", "SETNAME", self.client_name)
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
        if self.lib_version:
            await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db)
        # if client caching is enabled, start tracking (pipelined as well; its
        # reply is read last because it was sent last)
        if self.client_cache is not None:
            await self.send_command("CLIENT", "TRACKING", "ON")

        # read responses from pipeline, in send order
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                # older servers do not know CLIENT SETINFO; ignore
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

        if self.client_cache is not None:
            await self.read_response()
            self._parser.set_invalidation_push_handler(self._cache_invalidation_process)

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server.

        :param nowait: close the writer without awaiting ``wait_closed``
            (used when tearing down after an error).
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
        finally:
            # cached replies are tied to this connection's tracking session
            if self.client_cache is not None:
                self.client_cache.flush()

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on socket errors.
        """
        if not self.is_connected:
            await self.connect()
        elif check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param timeout: overrides ``socket_timeout``; when given and it
            expires, None is returned instead of raising.
        :param push_request: forwarded to the pure-Python RESP3 parser only.
        :raises TimeoutError: when ``socket_timeout`` expires.
        :raises ConnectionError: on socket errors.
        :raises ResponseError: when the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if (
                read_timeout is not None
                and self.protocol in ["3", 3]
                and not HIREDIS_AVAILABLE
            ):
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): relies on asyncio.StreamReader's private `_buffer`.
        return len(self._reader._buffer) == 0

    def _cache_invalidation_process(
        self, data: List[Union[str, Optional[List[str]]]]
    ) -> None:
        """
        Invalidate (delete) all redis commands associated with a specific key.
        `data` is a list whose first element is the invalidation message and
        whose second element is the list of keys to invalidate
        (if the list of keys is None, then all keys are invalidated).
        """
        keys = data[1]
        if keys is None:
            # A null key list means every cached entry is stale.
            self.client_cache.flush()
        else:
            for key in keys:
                self.client_cache.invalidate_key(str_if_bytes(key))

    async def _get_from_local_cache(self, command: str):
        """
        If the command is in the local cache, return the response
        """
        if (
            self.client_cache is None
            or command[0] in self.cache_blacklist
            or command[0] not in self.cache_whitelist
        ):
            return None
        # drain pending push messages (e.g. invalidations) before trusting the cache
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
        return self.client_cache.get(command)

    def _add_to_local_cache(
        self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
    ):
        """
        Add the command and response to the local cache if the command
        is allowed to be cached
        """
        if (
            self.client_cache is not None
            and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
            and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
        ):
            self.client_cache.set(command, response, keys)
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store TCP settings; remaining kwargs go to AbstractConnection.

        :param port: converted with ``int()``.
        :param socket_keepalive_options: extra ``SOL_TCP`` options applied
            only when ``socket_keepalive`` is true.
        :param socket_type: stored; not used by ``_connect`` below.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments for ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection.

        The reader/writer are only stored on ``self`` after the socket options
        were applied. If applying them fails, the writer is closed and the error
        re-raised, so ``is_connected`` never reports a closed stream as live.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)
            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

    def _error_message(self, exception: BaseException) -> str:
        # args for socket.error can either be (errno, "message")
        # or just "message"

        host_error = self._host_error()

        if not exception.args:
            # asyncio has a bug where on Connection reset by peer, the
            # exception is not instanciated, so args is empty. This is the
            # workaround.
            # See: https://github.com/redis/redis-py/issues/2237
            # See: https://github.com/python/cpython/issues/94061
            return f"Error connecting to {host_error}. Connection reset by peer"
        elif len(exception.args) == 1:
            return f"Error connecting to {host_error}. {exception.args[0]}."
        else:
            # (errno, message): report the errno first, then the message text.
            return (
                f"Error {exception.args[0]} connecting to {host_error}. "
                f"{exception.args[1]}."
            )
class SSLConnection(Connection):
    """TCP connection to Redis wrapped in TLS.

    All ``ssl_*`` options are collected into a :class:`RedisSSLContext`, which
    builds the ``ssl.SSLContext`` handed to ``asyncio.open_connection``
    (https://docs.python.org/3/library/ssl.html#ssl.SSLContext). Every other
    keyword argument is forwarded to :class:`Connection`.
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: str = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = False,
        ssl_min_version: Optional[ssl.TLSVersion] = None,
        **kwargs,
    ):
        tls_settings = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
        )
        self.ssl_context: RedisSSLContext = tls_settings
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """The plain TCP arguments extended with the ``ssl`` context."""
        return {**super()._connection_arguments(), "ssl": self.ssl_context.get()}

    @property
    def keyfile(self):
        """Client private key file given as ``ssl_keyfile``."""
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        """Client certificate file given as ``ssl_certfile``."""
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        """Verification mode resolved by the RedisSSLContext."""
        return self.ssl_context.cert_reqs

    @property
    def ca_certs(self):
        """CA bundle file given as ``ssl_ca_certs``."""
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        """In-memory CA data given as ``ssl_ca_data``."""
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        """Whether hostname checking was requested."""
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        """Minimum TLS version given as ``ssl_min_version``."""
        return self.ssl_context.min_version
class RedisSSLContext:
    """Holds SSL options and lazily builds the matching ``ssl.SSLContext``."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[ssl.TLSVersion] = None,
    ):
        """
        :param cert_reqs: ``None`` (treated as ``ssl.CERT_NONE``), one of the
            strings ``"none"``, ``"optional"``, ``"required"``, or an
            ``ssl.VerifyMode`` value such as ``ssl.CERT_REQUIRED``.
        :raises RedisError: if ``cert_reqs`` is an unrecognised string.
        """
        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            self.cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            self.cert_reqs = CERT_REQS[cert_reqs]
        else:
            # Already an ssl verify mode (e.g. ssl.CERT_REQUIRED). Previously this
            # branch was missing, leaving the slot unset so get() raised
            # AttributeError.
            self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.check_hostname = check_hostname
        self.min_version = min_version
        self.context: Optional[ssl.SSLContext] = None

    def get(self) -> ssl.SSLContext:
        """Return the ``ssl.SSLContext``, building and caching it on first call."""
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname is assigned first: verify_mode cannot be set to
            # CERT_NONE while hostname checking is still enabled.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            self.context = context
        return self.context
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """:param path: filesystem path of the Redis unix socket."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the unix-socket stream.

        The handshake is deliberately not run here: ``AbstractConnection.connect``
        runs ``on_connect`` (or ``redis_connect_func``) after ``_connect``
        returns. Calling it here as well performed the handshake twice, and
        ran the default handshake even when a custom ``redis_connect_func``
        was configured.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return self.path

    def _error_message(self, exception: BaseException) -> str:
        # args for socket.error can either be (errno, "message")
        # or just "message"; some asyncio errors carry no args at all.
        host_error = self._host_error()
        if not exception.args:
            return f"Error connecting to unix socket: {host_error}."
        elif len(exception.args) == 1:
            return (
                f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
            )
        else:
            return (
                f"Error {exception.args[0]} connecting to unix socket: "
                f"{host_error}. {exception.args[1]}."
            )
# Upper-cased spellings that a URL query value may use to mean "false".
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a URL query-string value as a boolean.

    ``None`` and the empty string mean "not set" and yield None. A string
    whose upper-cased form is listed in FALSE_STRINGS yields False. Anything
    else is passed through ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_word = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_word else bool(value)
# Converters that parse_url applies to recognised query-string options; any
# option not listed here is passed through as the unquoted string. Wrapped in
# MappingProxyType so the table is read-only.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "timeout": float,
    }
)
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments returned by ``parse_url``.

    Every key is optional (``total=False``). parse_url may also add arbitrary
    query-string options that are not declared here.
    """

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
def parse_url(url: str) -> ConnectKwargs:
    """Translate a Redis connection URL into connection keyword arguments.

    Supported schemes are ``redis://``, ``rediss://`` (adds SSLConnection) and
    ``unix://`` (adds UnixDomainSocketConnection). Query-string options are
    converted with URL_QUERY_ARGUMENT_PARSERS when known and kept as strings
    otherwise. For redis/rediss URLs, a numeric path is used as ``db`` unless
    the query string already set it.

    :raises ValueError: on an unsupported scheme or an unparsable known option.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for option, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(option)
        if not convert:
            kwargs[option] = raw
            continue
        try:
            kwargs[option] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for `{option}` in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    scheme = parsed.scheme
    if scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    if scheme not in ("redis", "rediss"):
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A path such as "/3" selects the db, unless the query string set it.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if scheme == "rediss":
        kwargs["connection_class"] = SSLConnection
    return kwargs
# Type variable for ConnectionPool and its subclasses, so that
# ``SubPool.from_url(...)`` is typed as returning ``SubPool``. The bound is a
# string forward reference because ConnectionPool is defined below.
_CP = TypeVar("_CP", bound="ConnectionPool")
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override explicitly passed keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # ``None`` (or ``0``) means "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)

    def __repr__(self):
        # Instantiates connection_class purely to borrow its repr.
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_class(**self.connection_kwargs)!r})>"
        )

    def reset(self):
        """Forget all tracked connections (idle and in use) without closing them.

        Uses the same container types as ``__init__`` so the pool behaves
        identically before and after a reset.
        """
        self._available_connections = []
        self._in_use_connections = set()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    async def get_connection(self, command_name, *keys, **options):
        """Get a connected connection from the pool.

        If the connection cannot be made ready, it is released back to the
        pool and the error is re-raised.
        """
        connection = self.get_available_connection()
        try:
            await self.ensure_connection(connection)
        except BaseException:
            await self.release(connection)
            raise

        return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected.

        Reuses an idle connection when one exists, otherwise creates a new one.

        Raises:
            ConnectionError: if ``max_connections`` connections are already
                in use.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise ConnectionError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid.

        Raises:
            ConnectionError: if the connection still has unread data after
                a reconnect attempt.
        """
        await connection.connect()
        # if client caching is not enabled connections that the pool
        # provides should be ready to send a command.
        # if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        # (if caching enabled the connection will not always be ready
        # to send a command because it may contain invalidation messages)
        try:
            if (
                await connection.can_read_destructive()
                and connection.client_cache is None
            ):
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects are attempted; the first exception raised by any of
        them is re-raised afterwards.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc is not None:
            raise exc

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign *retry* to every connection currently tracked by the pool."""
        for conn in chain(self._available_connections, self._in_use_connections):
            conn.retry = retry

    def _iter_client_caches(self):
        """Yield the client cache of each pooled connection that has one enabled."""
        for connection in chain(self._available_connections, self._in_use_connections):
            # A missing attribute or ``None`` both mean caching is disabled;
            # errors raised by the cache methods themselves are not swallowed.
            cache = getattr(connection, "client_cache", None)
            if cache is not None:
                yield cache

    def flush_cache(self):
        """Flush the client cache of every pooled connection that has one."""
        for cache in self._iter_client_caches():
            cache.flush()

    def delete_command_from_cache(self, command: str):
        """Delete *command* from every pooled connection's client cache."""
        for cache in self._iter_client_caches():
            cache.delete_command(command)

    def invalidate_key_from_cache(self, key: str):
        """Invalidate *key* in every pooled connection's client cache."""
        for cache in self._iter_client_caches():
            cache.invalidate_key(key)
class BlockingConnectionPool(ConnectionPool):
    """
    A connection pool that waits for a free connection instead of failing::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    Like :py:class:`~redis.asyncio.ConnectionPool`, it keeps a set of
    reusable connections that several async redis clients can share. The
    difference shows up when every connection is busy: instead of raising
    :py:class:`~redis.ConnectionError` immediately, the requesting `Task` is
    suspended for up to ``timeout`` seconds until another task releases a
    connection.

    ``max_connections`` sets the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    ``timeout`` sets how long to wait for a free connection, or ``None`` to
    wait forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[int] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        # ``queue_class`` is still accepted so existing callers keep working,
        # but nothing in this class reads it.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Serialises pool bookkeeping and lets get_connection() sleep until
        # release() signals that a connection has been returned.
        self._condition = asyncio.Condition()
        self.timeout = timeout

    async def get_connection(self, command_name, *keys, **options):
        """Return a ready connection, waiting up to ``self.timeout`` for one.

        Raises:
            ConnectionError: if no connection frees up before the timeout.
        """
        try:
            async with self._condition, async_timeout(self.timeout):
                await self._condition.wait_for(self.can_get_connection)
                connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # The readiness check runs after the lock has been released, so other
        # tasks can acquire or release connections while it awaits I/O.
        try:
            await self.ensure_connection(connection)
        except BaseException:
            await self.release(connection)
            raise
        return connection

    async def release(self, connection: AbstractConnection):
        """Return *connection* to the pool and wake one waiting task."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify()