1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38
39from ..auth.token import TokenInterface
40from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
41from ..utils import deprecated_args, format_error_message
42
43# the functionality is available in 3.11.x but has a major issue before
44# 3.11.3. See https://github.com/redis/redis-py/issues/2633
45if sys.version_info >= (3, 11, 3):
46 from asyncio import timeout as async_timeout
47else:
48 from async_timeout import timeout as async_timeout
49
50from redis.asyncio.retry import Retry
51from redis.backoff import NoBackoff
52from redis.connection import DEFAULT_RESP_VERSION
53from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
54from redis.exceptions import (
55 AuthenticationError,
56 AuthenticationWrongNumberOfArgsError,
57 ConnectionError,
58 DataError,
59 RedisError,
60 ResponseError,
61 TimeoutError,
62)
63from redis.typing import EncodableT
64from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
65
66from .._parsers import (
67 BaseParser,
68 Encoder,
69 _AsyncHiredisParser,
70 _AsyncRESP2Parser,
71 _AsyncRESP3Parser,
72)
73
# RESP wire-protocol framing bytes used by pack_command()/pack_commands().
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""


class _Sentinel(enum.Enum):
    """Private sentinel type so callers may pass ``None`` as a real value."""

    sentinel = object()


# Module-wide "argument not supplied" marker (see AbstractConnection.__init__).
SENTINEL = _Sentinel.sentinel


# Pick the fastest available response parser: the C-accelerated hiredis
# parser when installed, otherwise the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
93
94
class ConnectCallbackProtocol(Protocol):
    """Synchronous callback invoked with the connection after (re)connect."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Awaitable callback invoked with the connection after (re)connect."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Either flavour of connect callback is accepted by register_connect_callback().
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
104
105
class AbstractConnection:
    """Manages communication to and from a Redis server"""

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        # keep __dict__ so subclasses / late additions (e.g. _event_dispatcher,
        # _re_auth_token) can still set attributes not listed above
        "__dict__",
    )

    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """Configure (but do not open) the connection.

        Raises:
            DataError: if ``username``/``password`` are combined with a
                ``credential_provider``.
            ConnectionError: if ``protocol`` is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            # Work on a copy so the appends below never mutate the caller's
            # list in place.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None

        try:
            p = int(protocol)
        except TypeError:
            # e.g. protocol=None -- fall back to the library default
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # Validate outside the try/except. The previous ``finally``-based
        # check referenced ``p`` even when int() raised ValueError, so a
        # non-numeric protocol string surfaced as a NameError instead of
        # the intended ConnectionError.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = protocol

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

        try:
            # _close() schedules work on the event loop; only attempt it if
            # a loop is still running.
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No actions been taken if pool already closed.
            pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (key, value) pairs describing this connection for __repr__."""
        pass

    @property
    def is_connected(self):
        """True when both stream reader and writer are attached."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        await self.connect_check_health(check_health=True)

    async def connect_check_health(self, check_health: bool = True):
        """Connect, run the on-connect handshake and any registered callbacks.

        ``check_health`` is forwarded to the handshake commands; no-op when
        already connected.
        """
        if self.is_connected:
            return
        try:
            await self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect()
            )
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    @abstractmethod
    async def _connect(self):
        """Open the underlying transport (implemented by subclasses)."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint string for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the RESP protocol version this connection was created with."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake: AUTH/HELLO, CLIENT SETNAME/SETINFO
        and SELECT, pipelining where possible to reduce round trips."""
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as "default"
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # the proto key may be bytes or str depending on decode_responses
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # proto validation intentionally disabled here; see the
            # authenticated HELLO branch above for the active check
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort; old servers reject it
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server"""
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        """Write pre-packed command chunks to the socket and flush."""
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Send an already-packed command, reconnecting and optionally
        health-checking first; disconnects on any write failure."""
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                # no timeout configured: write directly via the shared helper
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command"""
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # push_request is only meaningful for RESP3 parsers
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Drain pending push (invalidation) messages from the socket buffer."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token so the connection re-authenticates on next re_auth()."""
        self._re_auth_token = token

    async def re_auth(self):
        """Re-authenticate with the stored token, then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
721
722
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Record TCP endpoint/socket options, then defer to the base class."""
        self.host = host
        self.port = int(port)
        self.socket_type = socket_type
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        super().__init__(**kwargs)

    def repr_pieces(self):
        """(key, value) pairs shown in __repr__."""
        base = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            base.append(("client_name", self.client_name))
        return base

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments forwarded to asyncio.open_connection()."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            connect_args = self._connection_arguments()
            reader, writer = await asyncio.open_connection(**connect_args)
        self._reader, self._writer = reader, writer
        sock = writer.transport.get_extra_info("socket")
        if not sock:
            return
        # disable Nagle's algorithm for lower command latency
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        try:
            # TCP_KEEPALIVE
            if self.socket_keepalive:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                for opt, val in self.socket_keepalive_options.items():
                    sock.setsockopt(socket.SOL_TCP, opt, val)

        except (OSError, TypeError):
            # `socket_keepalive_options` might contain invalid options
            # causing an error. Do not leave the connection open.
            writer.close()
            raise

    def _host_error(self) -> str:
        """Endpoint string used in error messages."""
        return f"{self.host}:{self.port}"
778
779
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # NOTE: annotations referencing ``ssl``/``TLSVersion`` are quoted so
        # they are not evaluated at import time -- when SSL support is
        # unavailable the module sets ``ssl = None`` and an eager
        # ``ssl.VerifyMode`` lookup would crash the whole module import
        # before the RedisError below could ever be raised.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional["TLSVersion"] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        """Build the RedisSSLContext wrapper, then defer to Connection.

        Raises:
            RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # augment the TCP arguments with the lazily-built SSLContext
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the underlying RedisSSLContext configuration.
    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
845
846
class RedisSSLContext:
    """Lazily builds and caches an ``ssl.SSLContext`` from redis-py's
    ``ssl_*`` keyword arguments."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted annotations: ``ssl`` is None when SSL support is missing,
        # so an eager ``ssl.VerifyMode`` lookup would break module import
        # before the RedisError guard below could run.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional["TLSVersion"] = None,
        ciphers: Optional[str] = None,
    ):
        """Normalize and store SSL settings; the context is built on get().

        Raises:
            RedisError: if Python was built without SSL support, or if
                ``cert_reqs`` is an unrecognized string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        # hostname checking is only valid when certificates are verified
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the configured SSLContext, building it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
914
915
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Record the socket path, then defer to the base class."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """(key, value) pairs shown in __repr__."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the unix-domain stream; no handshake is performed here.

        connect()/connect_check_health() runs on_connect_check_health() (or a
        custom ``redis_connect_func``) after _connect() succeeds, mirroring
        Connection._connect(). The previous explicit ``await self.on_connect()``
        here made the AUTH/HELLO/SELECT handshake run twice per connection and
        ran even when a custom connect callback was configured.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        """Endpoint string used in error messages."""
        return self.path
938
939
# Case-insensitive string spellings that parse_url() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a URL querystring value to ``bool``.

    ``None`` and the empty string yield ``None`` ("not specified"); strings
    listed in FALSE_STRINGS (case-insensitively) yield ``False``; anything
    else is coerced with ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
949
950
# Casts applied by parse_url() to recognized querystring options; any other
# querystring key is passed through to the connection class as a plain string.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "timeout": float,
    }
)
964
965
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments produced by parse_url(); all keys are optional."""

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
974
975
def parse_url(url: str) -> ConnectKwargs:
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into kwargs.

    Username, password, hostname, path and querystring values are
    percent-decoded. Recognized querystring options are cast via
    URL_QUERY_ARGUMENT_PARSERS; unknown ones are kept as strings.

    Raises:
        ValueError: if the scheme is unsupported or a recognized
            querystring option has an invalid value.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs() only yields non-empty lists, but guard defensively;
        # the old `value_list and len(value_list) > 0` was redundant.
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # non-numeric path: leave db unset
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection
    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1026
1027
# Type variable bound to ConnectionPool so that from_url() invoked on a
# subclass is typed as returning an instance of that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
1029
1030
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override identically-named keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (and 0, which is falsy) mean "effectively unlimited": use a
        # cap large enough never to be reached so the size check in
        # get_available_connection() can stay unconditional.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections ready to be handed out, plus the set of
        # connections currently checked out by callers.
        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Serializes get_connection() and re_auth_callback().
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them."""
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set here while reset() installs
        # a WeakSet, so connections checked out before a reset stop being
        # strongly referenced by the pool — presumably intentional so stale
        # connections can be garbage collected once their owners drop them;
        # confirm this asymmetry is deliberate.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # Wrap in bool(): _available_connections is a list, and returning it
        # directly would violate the declared -> bool interface.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # The docstring above used to sit *inside* the `async with` block as
        # a no-op string expression; it now documents the coroutine properly.
        async with self._lock:
            connection = self.get_available_connection()
            try:
                await self.ensure_connection(connection)
            except BaseException:
                # Return the connection before propagating; otherwise it
                # would leak out of _in_use_connections forever.
                await self.release(connection)
                raise

            return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            # Pool is empty: create a fresh connection unless we are at the
            # configured cap.
            if len(self._in_use_connections) >= self.max_connections:
                raise ConnectionError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        # Disconnect everything concurrently; collect failures rather than
        # aborting on the first one, then surface the first failure (if any).
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Install ``retry`` on every connection, idle and in-use alike."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly issued token.

        Idle connections are re-AUTHed immediately; in-use connections only
        get the token recorded so they can re-auth when convenient.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas close over `conn` but are invoked within the
                # same loop iteration, so late binding is not a problem here.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        Intentionally does nothing.
        :param error:
        :return:
        """
        pass
1254
1255
1256class BlockingConnectionPool(ConnectionPool):
1257 """
1258 A blocking connection pool::
1259
1260 >>> from redis.asyncio import Redis, BlockingConnectionPool
1261 >>> client = Redis.from_pool(BlockingConnectionPool())
1262
1263 It performs the same function as the default
1264 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
1265 it maintains a pool of reusable connections that can be shared by
1266 multiple async redis clients.
1267
1268 The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather than
1270 raising a :py:class:`~redis.ConnectionError` (as the default
1271 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
1272 blocks the current `Task` for a specified number of seconds until
1273 a connection becomes available.
1274
1275 Use ``max_connections`` to increase / decrease the pool size::
1276
1277 >>> pool = BlockingConnectionPool(max_connections=10)
1278
1279 Use ``timeout`` to tell it either how many seconds to wait for a connection
1280 to become available, or to block forever:
1281
1282 >>> # Block forever.
1283 >>> pool = BlockingConnectionPool(timeout=None)
1284
1285 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
1286 >>> # not available.
1287 >>> pool = BlockingConnectionPool(timeout=5)
1288 """
1289
1290 def __init__(
1291 self,
1292 max_connections: int = 50,
1293 timeout: Optional[int] = 20,
1294 connection_class: Type[AbstractConnection] = Connection,
1295 queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
1296 **connection_kwargs,
1297 ):
1298 super().__init__(
1299 connection_class=connection_class,
1300 max_connections=max_connections,
1301 **connection_kwargs,
1302 )
1303 self._condition = asyncio.Condition()
1304 self.timeout = timeout
1305
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available.

        Waits (up to ``self.timeout`` seconds; forever if ``None``) for the
        pool to have capacity, then checks the connection out.  Raises
        :py:class:`~redis.ConnectionError` if the wait times out.
        """
        try:
            # The timeout must wrap the wait *inside* the condition: the
            # condition lock is held only briefly while capacity is checked
            # and the connection is popped.
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            # Hand the connection back (waking a waiter) before propagating.
            await self.release(connection)
            raise
1328
    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one task blocked in get_connection() waiting for capacity.
            self._condition.notify()