1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion, VerifyFlags
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38 VerifyFlags = None
39
40from ..auth.token import TokenInterface
41from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
42from ..utils import deprecated_args, format_error_message
43
44# the functionality is available in 3.11.x but has a major issue before
45# 3.11.3. See https://github.com/redis/redis-py/issues/2633
46if sys.version_info >= (3, 11, 3):
47 from asyncio import timeout as async_timeout
48else:
49 from async_timeout import timeout as async_timeout
50
51from redis.asyncio.retry import Retry
52from redis.backoff import NoBackoff
53from redis.connection import DEFAULT_RESP_VERSION
54from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
55from redis.exceptions import (
56 AuthenticationError,
57 AuthenticationWrongNumberOfArgsError,
58 ConnectionError,
59 DataError,
60 MaxConnectionsError,
61 RedisError,
62 ResponseError,
63 TimeoutError,
64)
65from redis.typing import EncodableT
66from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
67
68from .._parsers import (
69 BaseParser,
70 Encoder,
71 _AsyncHiredisParser,
72 _AsyncRESP2Parser,
73 _AsyncRESP3Parser,
74)
75
# RESP wire-protocol building blocks used when packing commands.
SYM_STAR, SYM_DOLLAR = b"*", b"$"  # array-count prefix / bulk-length prefix
SYM_CRLF, SYM_LF = b"\r\n", b"\n"  # line terminators
SYM_EMPTY = b""  # joiner used to concatenate byte chunks
81
82
class _Sentinel(enum.Enum):
    """Single-member enum used as a "no argument passed" marker.

    Used as the default of keyword arguments where an explicit empty value
    must be distinguishable from "not given" (e.g. ``retry_on_error``).
    Presumably an Enum member is used (rather than a bare ``object()``) so
    type checkers can narrow ``Union[..., _Sentinel]`` via ``is SENTINEL``.
    """

    sentinel = object()


# Module-wide marker instance; always compare with ``is``.
SENTINEL = _Sentinel.sentinel
88
89
# Parser class used when the caller does not pass ``parser_class``: the
# hiredis-backed parser when hiredis is importable, else the RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
DefaultParser = _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
95
96
class ConnectCallbackProtocol(Protocol):
    """Synchronous callable invoked with the connection object."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Asynchronous callable invoked with the connection object."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Type of the ``redis_connect_func`` connection option: a plain or async
# callable that receives the connection. When set, it is called after the
# transport is opened in place of the default handshake.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
106
107
class AbstractConnection:
    """Manages communication to and from a Redis server.

    Transport-specific subclasses implement ``_connect`` (open the stream),
    ``_host_error`` (describe the endpoint for error messages) and
    ``repr_pieces``. This base class runs the connection handshake, packs and
    sends commands, reads responses and performs PING health checks.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """Configure the connection. No network I/O happens here.

        :param socket_connect_timeout: defaults to ``socket_timeout``.
        :param retry_on_error: extra exception types to retry on. The list is
            copied, never mutated.
        :param protocol: RESP version, 2 or 3 (int or numeric string). A value
            ``int()`` cannot convert by type (e.g. ``None``) selects
            ``DEFAULT_RESP_VERSION``.
        :raises DataError: if ``username``/``password`` are combined with
            ``credential_provider``.
        :raises ConnectionError: if ``protocol`` is not an integer, or is not
            2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        # Work on a private copy. The appends below would otherwise mutate the
        # caller's list, and a list shared between several connections' kwargs
        # would grow on every construction.
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        # Validate the RESP version. Plain control flow (no ``finally``) so a
        # ValueError surfaces as ConnectionError instead of UnboundLocalError.
        try:
            resp_version = int(protocol)
        except TypeError:
            # e.g. protocol=None: use the library default. Store it too, so
            # the handshake never sends an unusable value such as HELLO None.
            protocol = resp_version = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer") from None
        if resp_version < 2 or resp_version > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        # Keep the caller's representation (e.g. "3") for existing
        # ``protocol in [3, "3"]`` checks.
        self.protocol = protocol

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        # Only touch the writer when it exists: if __init__ failed early the
        # ``_writer`` slot is unset and _close() would raise AttributeError.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No running event loop: nothing can be closed from here.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs shown by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are attached."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error: self.disconnect(),
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the transport, run the handshake, then the connect callbacks.

        Does nothing if already connected.

        :param check_health: forwarded to the default handshake.
        :param retry_socket_connect: retry only the transport connect with
            ``self.retry``.
        :raises TimeoutError: if opening the transport times out.
        :raises ConnectionError: if opening the transport fails.
        :raises RedisError: if the handshake fails (the connection is closed).
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func. Call it and await
                # the result when awaitable. This handles plain functions,
                # coroutine functions and objects with ``async def __call__``
                # (AsyncConnectCallbackProtocol), which iscoroutinefunction()
                # does not recognise.
                result = self.redis_connect_func(self)
                if inspect.isawaitable(result):
                    await result
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback is None:
                # the listener was collected while an earlier callback was
                # being awaited
                continue
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection so ``should_reconnect()`` returns True."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect()`` has been called."""
        return self._should_reconnect

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a description of the endpoint for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the handshake: AUTH/HELLO, CLIENT SETNAME/SETINFO and SELECT.

        :raises ConnectionError: on an unexpected HELLO/SETNAME/SELECT reply.
        :raises AuthenticationError: if AUTH does not reply ``OK``.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # NOTE: unlike the AUTH branch above, the "proto" field of this
            # HELLO reply is read but not validated.
            response = await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline; SETINFO errors are deliberately ignored
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server

        :param nowait: close without awaiting ``wait_closed()``.
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Send an already-packed command, connecting first if needed.

        Any failure while writing disconnects the connection.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: if the socket write fails with ``OSError``.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            # OSError args may be (errno, strerror), a single message, or empty.
            if len(e.args) >= 2:
                err_no, errmsg = e.args[0], e.args[1]
            else:
                err_no = "UNKNOWN"
                errmsg = e.args[0] if e.args else str(e)
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param timeout: per-call read timeout. On expiry ``None`` is returned
            instead of raising. Defaults to ``socket_timeout``, which raises.
        :param disconnect_on_error: close the connection when reading fails.
        :param push_request: forwarded to the parser on RESP3 connections only.
        :raises TimeoutError: if ``socket_timeout`` expires.
        :raises ConnectionError: on socket errors.
        :raises ResponseError: when the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        parser_kwargs = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        try:
            if read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(**parser_kwargs)
            else:
                response = await self._parser.read_response(**parser_kwargs)
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): relies on the private StreamReader._buffer attribute.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read push messages until the reader's buffer is drained."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth()`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the pending token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
742
743
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store the TCP endpoint and socket options, then defer to the base class.

        :param port: converted with ``int()``, so numeric strings are accepted.
        :param socket_keepalive_options: ``SOL_TCP`` option -> value pairs,
            applied only when ``socket_keepalive`` is true.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        # NOTE(review): stored but not used when opening the connection below.
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection.

        The stream is attached to the connection only after every socket
        option has been applied. If configuring the socket fails, the new
        stream is closed and the connection stays disconnected (``is_connected``
        remains False) instead of holding a closed writer.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        try:
            self._configure_socket(writer.transport.get_extra_info("socket"))
        except (OSError, TypeError):
            # `socket_keepalive_options` might contain invalid options
            # causing an error. Do not leave the connection open.
            writer.close()
            raise
        self._reader = reader
        self._writer = writer

    def _configure_socket(self, sock) -> None:
        """Apply TCP_NODELAY and, if enabled, keepalive options to ``sock``."""
        if not sock:
            return
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # TCP_KEEPALIVE
        if self.socket_keepalive:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            for k, v in self.socket_keepalive_options.items():
                sock.setsockopt(socket.SOL_TCP, k, v)

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"
799
800
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted: annotations are evaluated at definition time, and ``ssl`` is
        # None when Python lacks SSL support, so an unquoted ``ssl.VerifyMode``
        # makes importing this module fail.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        """Build the SSL settings, then defer to ``Connection``.

        :raises RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP arguments plus the ``ssl`` context for ``asyncio.open_connection``."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the settings stored on ``self.ssl_context``.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

    @property
    def ciphers(self):
        return self.ssl_context.ciphers
878
879
class RedisSSLContext:
    """Holds redis-py SSL options and lazily builds a cached ``ssl.SSLContext``."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted so the annotation does not touch ``ssl`` (None without SSL
        # support) when the class body is executed.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
    ):
        """Store the SSL options. No context is built until ``get()``.

        :param cert_reqs: ``None`` (-> ``CERT_NONE``), one of the strings
            ``"none"``/``"optional"``/``"required"``, or an ``ssl.VerifyMode``.
        :param check_hostname: forced to False when ``cert_reqs`` is CERT_NONE.
        :raises RedisError: without SSL support, or for an unknown
            ``cert_reqs`` string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the ``ssl.SSLContext``, building and caching it on first use.

        :raises OSError: if a certificate/CA file cannot be read.
        :raises ssl.SSLError: if certificate material or ciphers are invalid.
        """
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname must be relaxed before verify_mode may be CERT_NONE
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile:
                # keyfile may be None: load_cert_chain then reads the private
                # key from certfile, so a combined cert+key file works.
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
959
960
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Store the socket ``path``, then defer to the base class."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the Unix domain socket stream within ``socket_connect_timeout``."""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        # The handshake is not run here: connect_check_health() runs it (or
        # redis_connect_func) right after _connect() returns, as for TCP.
        # Running it here too duplicated it and let handshake errors such as
        # AuthenticationError be re-wrapped as a generic ConnectionError.

    def _host_error(self) -> str:
        return self.path
983
984
# Case-insensitive spellings treated as "false" by to_bool().
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a URL query-string value as a boolean.

    ``None`` and ``""`` mean "not set" and yield ``None``. A string whose
    upper-cased form is in ``FALSE_STRINGS`` yields ``False``. Anything else
    is converted with ``bool()``.
    """
    if value is None or value == "":
        return None
    spelled_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if spelled_false else bool(value)
994
995
def parse_ssl_verify_flags(value):
    """Parse ``ssl.VerifyFlags`` member names from a URL query value.

    The value is a comma-separated list of flag names, optionally wrapped in
    square brackets, e.g. ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``.
    Surrounding whitespace is stripped and blank entries (e.g. from a trailing
    comma) are skipped.

    :returns: list of ``ssl.VerifyFlags`` members, in input order.
    :raises ValueError: if a name is not a ``VerifyFlags`` member. This
        includes every name when SSL support is unavailable.
    """
    # Look names up among real enum members only. hasattr()/getattr() also
    # accepted non-flag attributes such as "bit_length" or "__class__".
    # getattr() guards the no-SSL case, where VerifyFlags is None.
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for flag in value.replace("[", "").replace("]", "").split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
1008
1009
# Converters applied by parse_url() to recognised query-string options.
# Options not listed here are passed through as plain strings. Wrapped in
# MappingProxyType so the table is read-only.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        # float, matching the connection's ``health_check_interval: float``
        # parameter, so fractional values like "0.5" are accepted too.
        "health_check_interval": float,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1025
1026
# Keyword arguments that parse_url() derives from the URL's own components.
# total=False: every key is optional. parse_url() also stores query-string
# options under their own names, even though those keys are not declared here.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)
1035
1036
def parse_url(url: str) -> ConnectKwargs:
    """
    Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into
    connection keyword arguments.

    Query-string options are converted with ``URL_QUERY_ARGUMENT_PARSERS``
    when a converter exists for the option name, and kept as strings
    otherwise. Only the first occurrence of a repeated option is used.
    Username, password, host, port, socket path and database number are taken
    from the URL itself. A ``db`` query option takes precedence over a
    database number given in the path.

    :param url: the connection URL.
    :return: dict of keyword arguments (see ``ConnectKwargs``), including
        ``connection_class`` for the ``unix`` and ``rediss`` schemes.
    :raises ValueError: if the scheme is unsupported, or if a recognised query
        option cannot be converted to its expected type.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        if not value_list:
            continue
        # NOTE(review): parse_qs() has already percent-decoded the value, so
        # this unquote() decodes a second time (e.g. "%2541" ends up as "A").
        # Kept as-is because existing URLs may rely on it.
        value = unquote(value_list[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError) as err:
                # Chain the conversion error so the traceback shows the real
                # cause rather than "during handling of ... another exception".
                raise ValueError(
                    f"Invalid value for '{name}' in connection URL."
                ) from err
        else:
            kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # Non-numeric path (e.g. "/" or "/name"): leave db unset.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1088
1089
# Type variable bound to ConnectionPool so that ``from_url`` called on a
# subclass is typed as returning that subclass (see ``cls: Type[_CP]``).
_CP = TypeVar("_CP", bound="ConnectionPool")


class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more
          at: <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        Recognised querystring options are cast to their appropriate Python
        types. Unrecognised ones are passed through as strings.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (or 0) means "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections are kept in a list and popped from the end (LIFO).
        # Checked-out connections are tracked in a set.
        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them."""
        self._available_connections = []
        # NOTE(review): __init__ starts with a plain set(), but reset() switches
        # to a WeakSet. With a WeakSet, in-use connections that are garbage
        # collected without release() drop out of the in-use count. Confirm
        # which behaviour is intended.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() keeps the annotated return type honest. The bare ``or``
        # expression would return the list of idle connections itself.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool.

        The positional and keyword arguments are deprecated and ignored.
        Checkout and the connection health check both run while holding the
        pool lock.

        :raises MaxConnectionsError: if no idle connection exists and the
            pool is already at ``max_connections``.
        """
        async with self._lock:
            connection = self.get_available_connection()
            try:
                await self.ensure_connection(connection)
            except BaseException:
                # Return the connection (also on cancellation) so it does not
                # stay stuck in the in-use set.
                await self.release(connection)
                raise

        return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected.

        An idle connection is reused if one exists. Otherwise a new one is
        created, provided the in-use count is below ``max_connections``.

        :raises MaxConnectionsError: if the pool limit has been reached.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid.

        If the connection has unread data, or probing it fails, it is
        disconnected and reconnected once.

        :raises ConnectionError: if the connection still has unread data
            after reconnecting.
        """
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        A connection flagged by ``should_reconnect()`` is disconnected before
        it is put back. An ``AsyncAfterConnectionReleasedEvent`` is dispatched
        afterwards.
        """
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects run concurrently. Once every one has finished, the
        first exception raised by any of them (if any) is re-raised.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc is not None:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Install ``retry`` on every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a new token.

        Idle connections are sent ``AUTH <oid> <token value>`` immediately,
        and the reply is read, both under their retry policy. In-use
        connections only receive the token via ``set_re_auth_token()``
        (presumably applied when they are next free; confirm in the
        connection class).
        """
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        No-op error callback to pass to a retry object's ``call_with_retry``.

        :param error: the error reported by the retry object (ignored).
        :return: None
        """
        pass
1328
1329
class BlockingConnectionPool(ConnectionPool):
    """
    A connection pool that makes callers wait instead of failing::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    Like the default :py:class:`~redis.asyncio.ConnectionPool`, it keeps a
    set of reusable connections that can be shared by multiple async redis
    clients.

    The difference is what happens when every connection is checked out. The
    default pool raises a :py:class:`~redis.ConnectionError`. This pool
    instead suspends the requesting `Task` until a connection is released,
    for at most ``timeout`` seconds.

    Pool size is controlled with ``max_connections``::

        >>> pool = BlockingConnectionPool(max_connections=10)

    ``timeout`` is the number of seconds to wait for a free connection, or
    ``None`` to wait indefinitely:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        # ``queue_class`` is accepted for backwards compatibility only and is
        # not used. Waiting is done on an asyncio.Condition instead.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        self.timeout = timeout
        self._condition = asyncio.Condition()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Wait (up to ``self.timeout``) for a free connection and return it.

        The positional and keyword arguments are deprecated and ignored.

        :raises ConnectionError: if no connection frees up before the timeout.
        """
        try:
            # Checkout happens under the condition; waiting is bounded by the
            # timeout context that is entered inside it.
            async with self._condition, async_timeout(self.timeout):
                await self._condition.wait_for(self.can_get_connection)
                conn = super().get_available_connection()
        except asyncio.TimeoutError as exc:
            raise ConnectionError("No connection available.") from exc

        # The health check runs after the condition has been released, so
        # other tasks are not blocked while this connection is verified.
        try:
            await self.ensure_connection(conn)
        except BaseException:
            await self.release(conn)
            raise
        else:
            return conn

    async def release(self, connection: AbstractConnection):
        """Return ``connection`` to the pool and wake one waiting task."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify(n=1)