1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion, VerifyFlags
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38 VerifyFlags = None
39
40from ..auth.token import TokenInterface
41from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
42from ..utils import deprecated_args, format_error_message
43
44# the functionality is available in 3.11.x but has a major issue before
45# 3.11.3. See https://github.com/redis/redis-py/issues/2633
46if sys.version_info >= (3, 11, 3):
47 from asyncio import timeout as async_timeout
48else:
49 from async_timeout import timeout as async_timeout
50
51from redis.asyncio.retry import Retry
52from redis.backoff import NoBackoff
53from redis.connection import DEFAULT_RESP_VERSION
54from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
55from redis.exceptions import (
56 AuthenticationError,
57 AuthenticationWrongNumberOfArgsError,
58 ConnectionError,
59 DataError,
60 RedisError,
61 ResponseError,
62 TimeoutError,
63)
64from redis.typing import EncodableT
65from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
66
67from .._parsers import (
68 BaseParser,
69 Encoder,
70 _AsyncHiredisParser,
71 _AsyncRESP2Parser,
72 _AsyncRESP3Parser,
73)
74
# RESP wire-protocol framing bytes used by pack_command()/pack_commands().
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""
80
81
class _Sentinel(enum.Enum):
    # Unique marker used to distinguish "argument not supplied" from
    # legitimate values such as None or [].
    sentinel = object()


SENTINEL = _Sentinel.sentinel
87
88
# Prefer the hiredis-backed parser when the optional dependency is installed;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
94
95
class ConnectCallbackProtocol(Protocol):
    # Signature of a synchronous post-connect callback.
    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    # Signature of an awaitable post-connect callback.
    async def __call__(self, connection: "AbstractConnection"): ...


# Either flavour may be registered; connect_check_health() awaits the result
# when it is awaitable.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
105
106
class AbstractConnection:
    """Manages communication to and from a Redis server"""

    # __slots__ keeps per-connection memory low; "__dict__" is included so
    # subclasses may still set ad-hoc attributes (e.g. host/port or path).
    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )
138
139 def __init__(
140 self,
141 *,
142 db: Union[str, int] = 0,
143 password: Optional[str] = None,
144 socket_timeout: Optional[float] = None,
145 socket_connect_timeout: Optional[float] = None,
146 retry_on_timeout: bool = False,
147 retry_on_error: Union[list, _Sentinel] = SENTINEL,
148 encoding: str = "utf-8",
149 encoding_errors: str = "strict",
150 decode_responses: bool = False,
151 parser_class: Type[BaseParser] = DefaultParser,
152 socket_read_size: int = 65536,
153 health_check_interval: float = 0,
154 client_name: Optional[str] = None,
155 lib_name: Optional[str] = "redis-py",
156 lib_version: Optional[str] = get_lib_version(),
157 username: Optional[str] = None,
158 retry: Optional[Retry] = None,
159 redis_connect_func: Optional[ConnectCallbackT] = None,
160 encoder_class: Type[Encoder] = Encoder,
161 credential_provider: Optional[CredentialProvider] = None,
162 protocol: Optional[int] = 2,
163 event_dispatcher: Optional[EventDispatcher] = None,
164 ):
165 if (username or password) and credential_provider is not None:
166 raise DataError(
167 "'username' and 'password' cannot be passed along with 'credential_"
168 "provider'. Please provide only one of the following arguments: \n"
169 "1. 'password' and (optional) 'username'\n"
170 "2. 'credential_provider'"
171 )
172 if event_dispatcher is None:
173 self._event_dispatcher = EventDispatcher()
174 else:
175 self._event_dispatcher = event_dispatcher
176 self.db = db
177 self.client_name = client_name
178 self.lib_name = lib_name
179 self.lib_version = lib_version
180 self.credential_provider = credential_provider
181 self.password = password
182 self.username = username
183 self.socket_timeout = socket_timeout
184 if socket_connect_timeout is None:
185 socket_connect_timeout = socket_timeout
186 self.socket_connect_timeout = socket_connect_timeout
187 self.retry_on_timeout = retry_on_timeout
188 if retry_on_error is SENTINEL:
189 retry_on_error = []
190 if retry_on_timeout:
191 retry_on_error.append(TimeoutError)
192 retry_on_error.append(socket.timeout)
193 retry_on_error.append(asyncio.TimeoutError)
194 self.retry_on_error = retry_on_error
195 if retry or retry_on_error:
196 if not retry:
197 self.retry = Retry(NoBackoff(), 1)
198 else:
199 # deep-copy the Retry object as it is mutable
200 self.retry = copy.deepcopy(retry)
201 # Update the retry's supported errors with the specified errors
202 self.retry.update_supported_errors(retry_on_error)
203 else:
204 self.retry = Retry(NoBackoff(), 0)
205 self.health_check_interval = health_check_interval
206 self.next_health_check: float = -1
207 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208 self.redis_connect_func = redis_connect_func
209 self._reader: Optional[asyncio.StreamReader] = None
210 self._writer: Optional[asyncio.StreamWriter] = None
211 self._socket_read_size = socket_read_size
212 self.set_parser(parser_class)
213 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
214 self._buffer_cutoff = 6000
215 self._re_auth_token: Optional[TokenInterface] = None
216 self._should_reconnect = False
217
218 try:
219 p = int(protocol)
220 except TypeError:
221 p = DEFAULT_RESP_VERSION
222 except ValueError:
223 raise ConnectionError("protocol must be an integer")
224 finally:
225 if p < 2 or p > 3:
226 raise ConnectionError("protocol must be either 2 or 3")
227 self.protocol = protocol
228
    def __del__(self, _warnings: Any = warnings):
        # ``_warnings`` is bound as a default argument so the module object is
        # still reachable during interpreter shutdown.
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                # _close() schedules work on the loop; only attempt it when a
                # loop is actually running.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass
244
245 def _close(self):
246 """
247 Internal method to silently close the connection without waiting
248 """
249 if self._writer:
250 self._writer.close()
251 self._writer = self._reader = None
252
253 def __repr__(self):
254 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
255 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
256
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs to include in __repr__ (subclass hook)."""
        pass
260
261 @property
262 def is_connected(self):
263 return self._reader is not None and self._writer is not None
264
265 def register_connect_callback(self, callback):
266 """
267 Register a callback to be called when the connection is established either
268 initially or reconnected. This allows listeners to issue commands that
269 are ephemeral to the connection, for example pub/sub subscription or
270 key tracking. The callback must be a _method_ and will be kept as
271 a weak reference.
272 """
273 wm = weakref.WeakMethod(callback)
274 if wm not in self._connect_callbacks:
275 self._connect_callbacks.append(wm)
276
277 def deregister_connect_callback(self, callback):
278 """
279 De-register a previously registered callback. It will no-longer receive
280 notifications on connection events. Calling this is not required when the
281 listener goes away, since the callbacks are kept as weak methods.
282 """
283 try:
284 self._connect_callbacks.remove(weakref.WeakMethod(callback))
285 except ValueError:
286 pass
287
    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)
295
    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # Delegate to the health-checking variant with checks enabled.
        await self.connect_check_health(check_health=True)
299
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Establish the socket connection (optionally retrying via self.retry),
        run the handshake, then fire any registered connect callbacks.

        No-op when already connected. Timeouts surface as TimeoutError;
        OS-level failures as ConnectionError.
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
346
    def mark_for_reconnect(self):
        """Flag this connection to be re-established before its next use."""
        self._should_reconnect = True
349
    def should_reconnect(self):
        """Return True when mark_for_reconnect() has been called."""
        return self._should_reconnect
352
    @abstractmethod
    async def _connect(self):
        """Create the underlying transport (implemented by subclasses)."""
        pass
356
    @abstractmethod
    def _host_error(self) -> str:
        """Return a short endpoint description used in error messages."""
        pass
360
    def _error_message(self, exception: BaseException) -> str:
        """Format *exception* into a readable message including the endpoint."""
        return format_error_message(self._host_error(), exception)
363
    def get_protocol(self):
        # RESP protocol version exactly as supplied (may be an int or a str).
        return self.protocol
366
    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)
370
    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """
        Run the post-connect handshake: authenticate (plain AUTH or RESP3
        HELLO AUTH), switch the parser for RESP3 if needed, set the client
        and library names, and select the configured database.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # a bare password authenticates the "default" user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort (older servers reject it)
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
478
    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server"""
        try:
            # Reuse the connect timeout as an upper bound on teardown time.
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
501
502 async def _send_ping(self):
503 """Send PING, expect PONG in return"""
504 await self.send_command("PING", check_health=False)
505 if str_if_bytes(await self.read_response()) != "PONG":
506 raise ConnectionError("Bad response from PING health check")
507
    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        # Drop the connection; a later command will reconnect lazily.
        await self.disconnect()
511
512 async def check_health(self):
513 """Check the health of the connection with a PING/PONG"""
514 if (
515 self.health_check_interval
516 and asyncio.get_running_loop().time() > self.next_health_check
517 ):
518 await self.retry.call_with_retry(self._send_ping, self._ping_failed)
519
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Queue every chunk, then flush the transport's write buffer.
        self._writer.writelines(command)
        await self._writer.drain()
523
    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """
        Send an already-packed command, connecting first if necessary.

        On timeout or socket error the connection is dropped before the
        error is re-raised as TimeoutError/ConnectionError.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            # Normalize str/bytes input into the iterable-of-bytes form.
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise
564
565 async def send_command(self, *args: Any, **kwargs: Any) -> None:
566 """Pack and send a command to the Redis server"""
567 await self.send_packed_command(
568 self.pack_command(*args), check_health=kwargs.get("check_health", True)
569 )
570
    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        # "Destructive": any available data is consumed into the parser.
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
579
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command"""
        # An explicit ``timeout`` overrides socket_timeout; when an explicit
        # timeout expires, None is returned instead of raising.
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # ``push_request`` is only passed through on RESP3 connections.
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            # A successful read proves liveness; push the next check out.
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response
637
    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # RESP array header: *<number of elements>\r\n
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                # emit the accumulated header and the large arg as separate
                # chunks instead of joining them into one big bytestring
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
683
    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush accumulated small pieces before handling a
                    # large chunk / memoryview
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    # pass oversized chunks through without copying
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
713
714 def _socket_is_empty(self):
715 """Check if the socket is empty"""
716 return len(self._reader._buffer) == 0
717
    async def process_invalidation_messages(self):
        """Drain pending invalidation push messages from the socket."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
721
    def set_re_auth_token(self, token: TokenInterface):
        # Stash the token; re_auth() consumes it on its next call.
        self._re_auth_token = token
724
    async def re_auth(self):
        """Re-authenticate with the stored token, if one was set."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            # single-use: clear the token after a successful AUTH round-trip
            self._re_auth_token = None
734
735
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        # TCP-specific settings; everything else is handled by the base class.
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # Keyword arguments forwarded to asyncio.open_connection().
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # disable Nagle's algorithm: commands are small and latency-bound
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # "host:port" string used in error messages.
        return f"{self.host}:{self.port}"
791
792
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        # All TLS configuration is delegated to a lazily-built RedisSSLContext.
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # Supplying "ssl" makes asyncio.open_connection() negotiate TLS.
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # The properties below are read-only views of the RedisSSLContext settings.
    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
870
871
class RedisSSLContext:
    """Lazily builds and caches an ``ssl.SSLContext`` from redis-py options."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
    ):
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        # Normalize string cert requirements ("none"/"optional"/"required")
        # into ssl.VerifyMode values.
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        # Hostname checking is meaningless (and rejected by ssl) when
        # certificate verification is disabled, so force it off then.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        # Built on first get() call and reused afterwards.
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached SSLContext, building it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            # Apply verify-flag additions before removals.
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
951
952
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        # Filesystem path of the Redis unix socket.
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Create the unix-domain-socket transport (no handshake here)."""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        # Fix: do NOT call on_connect() here. connect_check_health() already
        # runs the on_connect handshake after _connect() succeeds, so the old
        # extra call made the UDS path authenticate / SELECT twice per
        # connection. This now matches Connection._connect().

    def _host_error(self) -> str:
        # Socket path shown in error messages.
        return self.path
975
976
# Upper-cased string values that parse_url treats as boolean False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """
    Coerce a URL query value to an optional boolean.

    ``None`` and the empty string mean "not specified" and yield ``None``.
    Strings matching FALSE_STRINGS (case-insensitively) yield False, any
    other string yields True, and non-string values fall back to ``bool``.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        return value.upper() not in FALSE_STRINGS
    return bool(value)
986
987
988def parse_ssl_verify_flags(value):
989 # flags are passed in as a string representation of a list,
990 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
991 verify_flags_str = value.replace("[", "").replace("]", "")
992
993 verify_flags = []
994 for flag in verify_flags_str.split(","):
995 flag = flag.strip()
996 if not hasattr(VerifyFlags, flag):
997 raise ValueError(f"Invalid ssl verify flag: {flag}")
998 verify_flags.append(getattr(VerifyFlags, flag))
999 return verify_flags
1000
1001
# Maps URL query-parameter names to converter callables applied by parse_url();
# parameters without an entry here are passed through as raw strings.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1017
1018
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments produced by parse_url() for building connections.

    All keys are optional (total=False): only the components present in
    the parsed URL appear in the result.
    """

    username: str
    password: str
    # Connection (TCP), SSLConnection, or UnixDomainSocketConnection,
    # selected from the URL scheme.
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    # Unix domain socket path (unix:// scheme only).
    path: str
1027
1028
def parse_url(url: str) -> ConnectKwargs:
    """Translate a redis://, rediss:// or unix:// URL into connection kwargs.

    Username, password, hostname, path and querystring values are
    percent-decoded. Known querystring options are coerced via
    URL_QUERY_ARGUMENT_PARSERS; unknown options pass through as strings.

    Raises ValueError for an unsupported scheme or an uncoercible option.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        value = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = value
        else:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        # For unix sockets the URL path is the socket file location.
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # The URL path doubles as the db number, unless a ?db= option
        # already set it.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection
    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1080
1081
# Bound type variable so ConnectionPool.from_url() returns the invoking
# subclass type rather than the base class.
_CP = TypeVar("_CP", bound="ConnectionPool")
1083
1084
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options take precedence over keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: class used to create new connections
        :param max_connections: cap on simultaneously checked-out
            connections; ``None`` (or 0) means effectively unbounded
        :param connection_kwargs: forwarded to ``connection_class``
        :raises ValueError: if ``max_connections`` is not a positive integer
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections ready to hand out, and connections currently
        # checked out by callers.
        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Serializes pool bookkeeping across concurrent get_connection() calls.
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them."""
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set() here while reset()
        # switches to a WeakSet — presumably so connections orphaned by a
        # reset can be garbage collected; confirm before unifying the two.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() so the annotated return type holds even when the truthy
        # value is the (non-empty) list of available connections.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        async with self._lock:
            connection = self.get_available_connection()
            try:
                await self.ensure_connection(connection)
            except BaseException:
                # Put the connection back before propagating so the pool
                # does not leak a slot.
                await self.release(connection)
                raise

        return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected

        :raises ConnectionError: if the pool is at ``max_connections``
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise ConnectionError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        :raises BaseException: re-raises the first exception raised by any
            individual ``disconnect()`` call, after all have completed
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Apply ``retry`` to every connection, idle and in-use alike."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections after a token refresh.

        Idle connections are AUTHed immediately; in-use connections are
        handed the token so they re-auth when released.
        """
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1320
1321
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    Like the default :py:class:`~redis.asyncio.ConnectionPool`, this
    maintains a pool of reusable connections shared by multiple async
    redis clients.

    It differs in how pool exhaustion is handled: where the default
    implementation raises a :py:class:`~redis.ConnectionError` when every
    connection is in use, this pool instead blocks the current `Task`
    for up to a configurable number of seconds, waiting for a connection
    to be returned.

    The pool size is controlled with ``max_connections``::

        >>> pool = BlockingConnectionPool(max_connections=10)

    ``timeout`` controls how long to wait for a free connection — pass
    ``None`` to wait forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[int] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Seconds to wait for a free connection; None blocks forever.
        self.timeout = timeout
        # Waiters block on this condition until release() notifies them.
        self._condition = asyncio.Condition()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # The connection health check happens outside the condition lock
        # so that other waiters are not blocked on network round-trips.
        try:
            await self.ensure_connection(connection)
        except BaseException:
            await self.release(connection)
            raise
        return connection

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one task waiting in get_connection().
            self._condition.notify()