1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38
39from ..auth.token import TokenInterface
40from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
41from ..utils import deprecated_args, format_error_message
42
43# the functionality is available in 3.11.x but has a major issue before
44# 3.11.3. See https://github.com/redis/redis-py/issues/2633
45if sys.version_info >= (3, 11, 3):
46 from asyncio import timeout as async_timeout
47else:
48 from async_timeout import timeout as async_timeout
49
50from redis.asyncio.retry import Retry
51from redis.backoff import NoBackoff
52from redis.connection import DEFAULT_RESP_VERSION
53from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
54from redis.exceptions import (
55 AuthenticationError,
56 AuthenticationWrongNumberOfArgsError,
57 ConnectionError,
58 DataError,
59 RedisError,
60 ResponseError,
61 TimeoutError,
62)
63from redis.typing import EncodableT
64from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
65
66from .._parsers import (
67 BaseParser,
68 Encoder,
69 _AsyncHiredisParser,
70 _AsyncRESP2Parser,
71 _AsyncRESP3Parser,
72)
73
# RESP wire-protocol framing bytes used by pack_command()/pack_commands().
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""
79
80
class _Sentinel(enum.Enum):
    """Sentinel type used to distinguish "argument omitted" from explicit None."""

    sentinel = object()


# Canonical sentinel instance (default for ``retry_on_error``).
SENTINEL = _Sentinel.sentinel
86
87
# Prefer the C-accelerated hiredis parser when the optional dependency is
# installed; otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
93
94
class ConnectCallbackProtocol(Protocol):
    """Signature of a synchronous post-connect callback."""

    def __call__(self, connection: "AbstractConnection"): ...
97
98
class AsyncConnectCallbackProtocol(Protocol):
    """Signature of an awaitable post-connect callback."""

    async def __call__(self, connection: "AbstractConnection"): ...
101
102
103ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
104
105
106class AbstractConnection:
107 """Manages communication to and from a Redis server"""
108
109 __slots__ = (
110 "db",
111 "username",
112 "client_name",
113 "lib_name",
114 "lib_version",
115 "credential_provider",
116 "password",
117 "socket_timeout",
118 "socket_connect_timeout",
119 "redis_connect_func",
120 "retry_on_timeout",
121 "retry_on_error",
122 "health_check_interval",
123 "next_health_check",
124 "last_active_at",
125 "encoder",
126 "ssl_context",
127 "protocol",
128 "_reader",
129 "_writer",
130 "_parser",
131 "_connect_callbacks",
132 "_buffer_cutoff",
133 "_lock",
134 "_socket_read_size",
135 "__dict__",
136 )
137
138 def __init__(
139 self,
140 *,
141 db: Union[str, int] = 0,
142 password: Optional[str] = None,
143 socket_timeout: Optional[float] = None,
144 socket_connect_timeout: Optional[float] = None,
145 retry_on_timeout: bool = False,
146 retry_on_error: Union[list, _Sentinel] = SENTINEL,
147 encoding: str = "utf-8",
148 encoding_errors: str = "strict",
149 decode_responses: bool = False,
150 parser_class: Type[BaseParser] = DefaultParser,
151 socket_read_size: int = 65536,
152 health_check_interval: float = 0,
153 client_name: Optional[str] = None,
154 lib_name: Optional[str] = "redis-py",
155 lib_version: Optional[str] = get_lib_version(),
156 username: Optional[str] = None,
157 retry: Optional[Retry] = None,
158 redis_connect_func: Optional[ConnectCallbackT] = None,
159 encoder_class: Type[Encoder] = Encoder,
160 credential_provider: Optional[CredentialProvider] = None,
161 protocol: Optional[int] = 2,
162 event_dispatcher: Optional[EventDispatcher] = None,
163 ):
164 if (username or password) and credential_provider is not None:
165 raise DataError(
166 "'username' and 'password' cannot be passed along with 'credential_"
167 "provider'. Please provide only one of the following arguments: \n"
168 "1. 'password' and (optional) 'username'\n"
169 "2. 'credential_provider'"
170 )
171 if event_dispatcher is None:
172 self._event_dispatcher = EventDispatcher()
173 else:
174 self._event_dispatcher = event_dispatcher
175 self.db = db
176 self.client_name = client_name
177 self.lib_name = lib_name
178 self.lib_version = lib_version
179 self.credential_provider = credential_provider
180 self.password = password
181 self.username = username
182 self.socket_timeout = socket_timeout
183 if socket_connect_timeout is None:
184 socket_connect_timeout = socket_timeout
185 self.socket_connect_timeout = socket_connect_timeout
186 self.retry_on_timeout = retry_on_timeout
187 if retry_on_error is SENTINEL:
188 retry_on_error = []
189 if retry_on_timeout:
190 retry_on_error.append(TimeoutError)
191 retry_on_error.append(socket.timeout)
192 retry_on_error.append(asyncio.TimeoutError)
193 self.retry_on_error = retry_on_error
194 if retry or retry_on_error:
195 if not retry:
196 self.retry = Retry(NoBackoff(), 1)
197 else:
198 # deep-copy the Retry object as it is mutable
199 self.retry = copy.deepcopy(retry)
200 # Update the retry's supported errors with the specified errors
201 self.retry.update_supported_errors(retry_on_error)
202 else:
203 self.retry = Retry(NoBackoff(), 0)
204 self.health_check_interval = health_check_interval
205 self.next_health_check: float = -1
206 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
207 self.redis_connect_func = redis_connect_func
208 self._reader: Optional[asyncio.StreamReader] = None
209 self._writer: Optional[asyncio.StreamWriter] = None
210 self._socket_read_size = socket_read_size
211 self.set_parser(parser_class)
212 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
213 self._buffer_cutoff = 6000
214 self._re_auth_token: Optional[TokenInterface] = None
215
216 try:
217 p = int(protocol)
218 except TypeError:
219 p = DEFAULT_RESP_VERSION
220 except ValueError:
221 raise ConnectionError("protocol must be an integer")
222 finally:
223 if p < 2 or p > 3:
224 raise ConnectionError("protocol must be either 2 or 3")
225 self.protocol = protocol
226
    def __del__(self, _warnings: Any = warnings):
        # ``warnings`` is bound as a default argument so it remains reachable
        # during interpreter shutdown, when module globals may be cleared.
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

        try:
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No running event loop; nothing can (or needs to) be closed.
            pass
242
243 def _close(self):
244 """
245 Internal method to silently close the connection without waiting
246 """
247 if self._writer:
248 self._writer.close()
249 self._writer = self._reader = None
250
251 def __repr__(self):
252 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
253 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
254
    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs shown in __repr__; transport-specific."""
        pass
258
    @property
    def is_connected(self) -> bool:
        # Connected iff both stream endpoints exist.
        return self._reader is not None and self._writer is not None
262
263 def register_connect_callback(self, callback):
264 """
265 Register a callback to be called when the connection is established either
266 initially or reconnected. This allows listeners to issue commands that
267 are ephemeral to the connection, for example pub/sub subscription or
268 key tracking. The callback must be a _method_ and will be kept as
269 a weak reference.
270 """
271 wm = weakref.WeakMethod(callback)
272 if wm not in self._connect_callbacks:
273 self._connect_callbacks.append(wm)
274
275 def deregister_connect_callback(self, callback):
276 """
277 De-register a previously registered callback. It will no-longer receive
278 notifications on connection events. Calling this is not required when the
279 listener goes away, since the callbacks are kept as weak methods.
280 """
281 try:
282 self._connect_callbacks.remove(weakref.WeakMethod(callback))
283 except ValueError:
284 pass
285
    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Replace the active parser with a fresh ``parser_class`` instance,
        configured with this connection's ``_socket_read_size``.

        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)
293
    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # Delegates to connect_check_health with health checking enabled.
        await self.connect_check_health(check_health=True)
297
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket, run the Redis handshake and fire connect callbacks.

        :param check_health: forwarded to the default handshake
            (``on_connect_check_health``); ignored when ``redis_connect_func``
            is set.
        :param retry_socket_connect: when True, the raw socket connect is
            retried per ``self.retry``; the handshake itself is never retried.
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            # Anything else is surfaced uniformly as a ConnectionError.
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
344
    @abstractmethod
    async def _connect(self):
        """Open the underlying transport; subclasses set _reader/_writer."""
        pass
348
    @abstractmethod
    def _host_error(self) -> str:
        """Return a short peer identifier (e.g. host:port) for error messages."""
        pass
352
    def _error_message(self, exception: BaseException) -> str:
        """Format *exception* into a readable message naming the peer."""
        return format_error_message(self._host_error(), exception)
355
    def get_protocol(self):
        """Return the configured RESP protocol version (as given by the caller)."""
        return self.protocol
358
    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)
362
    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake: AUTH/HELLO, CLIENT SETNAME/SETINFO
        and SELECT, pipelining where possible.

        :param check_health: forwarded to send_command for the non-AUTH
            commands; AUTH/HELLO-with-AUTH always skip the health check
            because PING would fail before authentication.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as the default user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort: older servers reject it
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
470
    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server.

        :param nowait: when True, do not await the transport's close
            handshake (used when tearing down after an error).
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    # Always drop the stream references, even if close failed.
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
493
494 async def _send_ping(self):
495 """Send PING, expect PONG in return"""
496 await self.send_command("PING", check_health=False)
497 if str_if_bytes(await self.read_response()) != "PONG":
498 raise ConnectionError("Bad response from PING health check")
499
    async def _ping_failed(self, error):
        """Function to call when PING fails: drop the (unhealthy) connection."""
        await self.disconnect()
503
    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        # Only ping when the interval is enabled and has elapsed.
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)
511
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Write all chunks, then flush; callers add the timeout bound if any.
        self._writer.writelines(command)
        await self._writer.drain()
515
516 async def send_packed_command(
517 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
518 ) -> None:
519 if not self.is_connected:
520 await self.connect_check_health(check_health=False)
521 if check_health:
522 await self.check_health()
523
524 try:
525 if isinstance(command, str):
526 command = command.encode()
527 if isinstance(command, bytes):
528 command = [command]
529 if self.socket_timeout:
530 await asyncio.wait_for(
531 self._send_packed_command(command), self.socket_timeout
532 )
533 else:
534 self._writer.writelines(command)
535 await self._writer.drain()
536 except asyncio.TimeoutError:
537 await self.disconnect(nowait=True)
538 raise TimeoutError("Timeout writing to socket") from None
539 except OSError as e:
540 await self.disconnect(nowait=True)
541 if len(e.args) == 1:
542 err_no, errmsg = "UNKNOWN", e.args[0]
543 else:
544 err_no = e.args[0]
545 errmsg = e.args[1]
546 raise ConnectionError(
547 f"Error {err_no} while writing to socket. {errmsg}."
548 ) from e
549 except BaseException:
550 # BaseExceptions can be raised when a socket send operation is not
551 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
552 # to send un-sent data. However, the send_packed_command() API
553 # does not support it so there is no point in keeping the connection open.
554 await self.disconnect(nowait=True)
555 raise
556
557 async def send_command(self, *args: Any, **kwargs: Any) -> None:
558 """Pack and send a command to the Redis server"""
559 await self.send_packed_command(
560 self.pack_command(*args), check_health=kwargs.get("check_health", True)
561 )
562
    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read.

        Destructive: any bytes pulled from the socket are consumed by the
        parser. Disconnects on OS-level read errors.
        """
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
571
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: skip encoder-based decoding of the reply.
        :param timeout: per-call read timeout; on expiry returns ``None``
            (retryable). When omitted, ``socket_timeout`` applies and expiry
            raises :class:`TimeoutError` instead.
        :param disconnect_on_error: drop the connection on read failures.
        :param push_request: RESP3 only -- forwarded to the parser so push
            messages are handled.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        # Schedule the next health check only after a successful read.
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        # Server-side errors are surfaced to the caller as exceptions.
        if isinstance(response, ResponseError):
            raise response from None
        return response
629
    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        # NOTE: this assert is a sanity check (stripped under ``python -O``),
        # not input validation.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # RESP array header: *<number of args>\r\n
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                # small argument: append the fully framed value inline
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
675
    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced into shared buffers up to
        ``_buffer_cutoff``; oversized chunks and memoryviews are emitted
        as-is to avoid large copies.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush the accumulated small pieces first
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                    if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                        output.append(chunk)
                    else:
                        pieces.append(chunk)
                        buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
705
    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): peeks at StreamReader's private ``_buffer`` -- there is
        # no public API for this, so it may break with asyncio changes.
        return len(self._reader._buffer) == 0
709
    async def process_invalidation_messages(self):
        """Drain pending push (invalidation) messages from the socket buffer."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
713
    def set_re_auth_token(self, token: TokenInterface) -> None:
        """Stash a token for the next re_auth() call to authenticate with."""
        self._re_auth_token = token
716
    async def re_auth(self):
        """Re-authenticate using the stored token, then clear it (one-shot)."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
726
727
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """TCP-specific settings; remaining kwargs go to AbstractConnection."""
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return the (name, value) pairs displayed by __repr__."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments forwarded to asyncio.open_connection()."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle's algorithm: commands are small and latency-bound.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        """Identify the peer in error messages as host:port."""
        return f"{self.host}:{self.port}"
783
784
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        """Build the lazily-materialized SSL context; other kwargs go to
        Connection.

        :raises RedisError: if the interpreter lacks SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # Add the materialized SSLContext so open_connection wraps the socket.
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    @property
    def keyfile(self):
        """Path to the client private key file, if any."""
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        """Path to the client certificate file, if any."""
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        """Effective ssl.VerifyMode used for peer certificates."""
        return self.ssl_context.cert_reqs

    @property
    def ca_certs(self):
        """Path to the CA bundle file, if any."""
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        """In-memory CA certificate data, if any."""
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        """Whether the server hostname is verified against its certificate."""
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        """Minimum TLS version accepted, or None for the ssl default."""
        return self.ssl_context.min_version
850
851
class RedisSSLContext:
    """Lazily builds and caches an ssl.SSLContext from redis-py SSL options."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
    ):
        """Normalize SSL options; the context itself is built on first get().

        :raises RedisError: if SSL is unavailable or ``cert_reqs`` is an
            unrecognized string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            # map user-friendly strings to ssl.VerifyMode values
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        # hostname checking is meaningless without certificate verification
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached SSLContext, building it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
919
920
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """``path``: filesystem path of the Redis unix socket."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return the (name, value) pairs displayed by __repr__."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the unix-domain socket and run the connect handshake."""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        # NOTE(review): connect_check_health() also runs the handshake after
        # _connect() returns; confirm this direct on_connect() call (absent in
        # the TCP Connection._connect) is intentional.
        await self.on_connect()

    def _host_error(self) -> str:
        """Identify the peer in error messages by its socket path."""
        return self.path
943
944
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a URL querystring value to a boolean.

    Returns None for None or the empty string (option not specified),
    False for case-insensitive false-like strings (see FALSE_STRINGS),
    otherwise the truthiness of *value*.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        return value.upper() not in FALSE_STRINGS
    return bool(value)
954
955
# Coercers applied to recognised querystring options in parse_url();
# unrecognised options are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "timeout": float,
    }
)
969
970
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments produced by parse_url(); every key is optional."""

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
979
980
def parse_url(url: str) -> ConnectKwargs:
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into connection
    keyword arguments.

    Username, password, hostname, path and querystring values are
    percent-decoded; recognised query options are coerced via
    URL_QUERY_ARGUMENT_PARSERS.

    :raises ValueError: on an unsupported scheme or an invalid query value.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs only yields non-empty lists; use the first occurrence.
        # (The previous ``value_list and len(value_list) > 0`` was redundant.)
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection
    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1031
1032
1033_CP = TypeVar("_CP", bound="ConnectionPool")
1034
1035
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options overwrite conflicting keyword arguments, as
        # documented above ("querystring arguments always win").
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # Treat "no limit" (None or 0) as a practically-unbounded pool.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections ready to be handed out.
        self._available_connections: List[AbstractConnection] = []
        # Connections currently checked out by callers.
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Serializes get_connection() and re_auth_callback() against each other.
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections, returning the pool to an empty state."""
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set() here while reset() uses a
        # WeakSet, so connections still checked out when reset() is called are
        # only weakly referenced afterwards -- confirm this asymmetry is
        # intentional.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # NOTE: ``or`` returns its first truthy operand, so this may return the
        # (truthy, non-empty) list itself rather than a literal bool; callers
        # only use it as a predicate.
        return (
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        # The whole checkout is serialized behind the pool lock.
        async with self._lock:
            """Get a connected connection from the pool"""
            connection = self.get_available_connection()
            try:
                await self.ensure_connection(connection)
            except BaseException:
                # Return the slot on any failure (including cancellation) so
                # the pool's accounting stays correct.
                await self.release(connection)
                raise

            return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            # No idle connection: create a new one unless the pool is full.
            if len(self._in_use_connections) >= self.max_connections:
                raise ConnectionError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            # Stale or broken socket: reconnect once, then re-check.
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        # Disconnect everything concurrently; collect failures instead of
        # letting the first one abort the remaining disconnects.
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        # Re-raise the first failure (if any) once all disconnects finished.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Attach ``retry`` policy to every connection currently in the pool."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly issued token.

        Idle connections are re-AUTHed immediately; in-use connections are
        handed the token so they can re-authenticate themselves later.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas close over ``conn`` but are awaited within this
                # same loop iteration, so late binding is not an issue here.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error: ignored.
        :return: None
        """
        pass
1260
1261
class BlockingConnectionPool(ConnectionPool):
    """
    A connection pool that blocks instead of raising when exhausted::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    Like the default :py:class:`~redis.asyncio.ConnectionPool`
    implementation, this maintains a pool of reusable connections that can
    be shared by multiple async redis clients.

    The difference lies in what happens when a client asks for a
    connection while every connection is in use: instead of raising
    :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), the
    current `Task` is suspended for up to a specified number of seconds
    until a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[int] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        super().__init__(
            max_connections=max_connections,
            connection_class=connection_class,
            **connection_kwargs,
        )
        # Seconds to wait for a free connection; None means wait forever.
        self.timeout = timeout
        # Waiting tasks block on this condition until release() notifies.
        self._condition = asyncio.Condition()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Return a connected connection, waiting until one is available."""
        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    conn = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # Validate the connection outside the condition lock so that other
        # waiting tasks are not blocked while we (re)connect.
        try:
            await self.ensure_connection(conn)
        except BaseException:
            await self.release(conn)
            raise
        return conn

    async def release(self, connection: AbstractConnection):
        """Return ``connection`` to the pool and wake one waiting task."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify()