1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion, VerifyFlags
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38 VerifyFlags = None
39
40from ..auth.token import TokenInterface
41from ..driver_info import DriverInfo, resolve_driver_info
42from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
43from ..utils import deprecated_args, format_error_message
44
45# the functionality is available in 3.11.x but has a major issue before
46# 3.11.3. See https://github.com/redis/redis-py/issues/2633
47if sys.version_info >= (3, 11, 3):
48 from asyncio import timeout as async_timeout
49else:
50 from async_timeout import timeout as async_timeout
51
52from redis.asyncio.retry import Retry
53from redis.backoff import NoBackoff
54from redis.connection import DEFAULT_RESP_VERSION
55from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
56from redis.exceptions import (
57 AuthenticationError,
58 AuthenticationWrongNumberOfArgsError,
59 ConnectionError,
60 DataError,
61 MaxConnectionsError,
62 RedisError,
63 ResponseError,
64 TimeoutError,
65)
66from redis.typing import EncodableT
67from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
68
69from .._parsers import (
70 BaseParser,
71 Encoder,
72 _AsyncHiredisParser,
73 _AsyncRESP2Parser,
74 _AsyncRESP3Parser,
75)
76
# RESP protocol framing bytes used by pack_command()/pack_commands().
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""


class _Sentinel(enum.Enum):
    """Enum-based sentinel distinguishing "argument not supplied" from None."""

    sentinel = object()


# Module-level sentinel instance; compared with ``is`` in keyword defaults.
SENTINEL = _Sentinel.sentinel
89
90
# Prefer the hiredis-backed parser when the optional hiredis package is
# installed; otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
96
97
class ConnectCallbackProtocol(Protocol):
    """Structural type of a synchronous connect callback."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type of an asynchronous (awaitable) connect callback."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Either flavor of callback may be registered on a connection.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
107
108
class AbstractConnection:
    """Manages communication to and from a Redis server"""

    # Restrict per-instance attribute storage to save memory in
    # connection-heavy applications. "__dict__" is retained so attributes
    # not listed here (e.g. driver_info, _event_dispatcher, _re_auth_token,
    # _should_reconnect, subclass attributes) can still be assigned.
    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )
140
141 @deprecated_args(
142 args_to_warn=["lib_name", "lib_version"],
143 reason="Use 'driver_info' parameter instead. "
144 "lib_name and lib_version will be removed in a future version.",
145 )
146 def __init__(
147 self,
148 *,
149 db: Union[str, int] = 0,
150 password: Optional[str] = None,
151 socket_timeout: Optional[float] = None,
152 socket_connect_timeout: Optional[float] = None,
153 retry_on_timeout: bool = False,
154 retry_on_error: Union[list, _Sentinel] = SENTINEL,
155 encoding: str = "utf-8",
156 encoding_errors: str = "strict",
157 decode_responses: bool = False,
158 parser_class: Type[BaseParser] = DefaultParser,
159 socket_read_size: int = 65536,
160 health_check_interval: float = 0,
161 client_name: Optional[str] = None,
162 lib_name: Optional[str] = None,
163 lib_version: Optional[str] = None,
164 driver_info: Optional[DriverInfo] = None,
165 username: Optional[str] = None,
166 retry: Optional[Retry] = None,
167 redis_connect_func: Optional[ConnectCallbackT] = None,
168 encoder_class: Type[Encoder] = Encoder,
169 credential_provider: Optional[CredentialProvider] = None,
170 protocol: Optional[int] = 2,
171 event_dispatcher: Optional[EventDispatcher] = None,
172 ):
173 """
174 Initialize a new async Connection.
175
176 Parameters
177 ----------
178 driver_info : DriverInfo, optional
179 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
180 are ignored. If not provided, a DriverInfo will be created from lib_name
181 and lib_version (or defaults if those are also None).
182 lib_name : str, optional
183 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
184 lib_version : str, optional
185 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
186 """
187 if (username or password) and credential_provider is not None:
188 raise DataError(
189 "'username' and 'password' cannot be passed along with 'credential_"
190 "provider'. Please provide only one of the following arguments: \n"
191 "1. 'password' and (optional) 'username'\n"
192 "2. 'credential_provider'"
193 )
194 if event_dispatcher is None:
195 self._event_dispatcher = EventDispatcher()
196 else:
197 self._event_dispatcher = event_dispatcher
198 self.db = db
199 self.client_name = client_name
200
201 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
202 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
203
204 self.credential_provider = credential_provider
205 self.password = password
206 self.username = username
207 self.socket_timeout = socket_timeout
208 if socket_connect_timeout is None:
209 socket_connect_timeout = socket_timeout
210 self.socket_connect_timeout = socket_connect_timeout
211 self.retry_on_timeout = retry_on_timeout
212 if retry_on_error is SENTINEL:
213 retry_on_error = []
214 if retry_on_timeout:
215 retry_on_error.append(TimeoutError)
216 retry_on_error.append(socket.timeout)
217 retry_on_error.append(asyncio.TimeoutError)
218 self.retry_on_error = retry_on_error
219 if retry or retry_on_error:
220 if not retry:
221 self.retry = Retry(NoBackoff(), 1)
222 else:
223 # deep-copy the Retry object as it is mutable
224 self.retry = copy.deepcopy(retry)
225 # Update the retry's supported errors with the specified errors
226 self.retry.update_supported_errors(retry_on_error)
227 else:
228 self.retry = Retry(NoBackoff(), 0)
229 self.health_check_interval = health_check_interval
230 self.next_health_check: float = -1
231 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
232 self.redis_connect_func = redis_connect_func
233 self._reader: Optional[asyncio.StreamReader] = None
234 self._writer: Optional[asyncio.StreamWriter] = None
235 self._socket_read_size = socket_read_size
236 self.set_parser(parser_class)
237 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
238 self._buffer_cutoff = 6000
239 self._re_auth_token: Optional[TokenInterface] = None
240 self._should_reconnect = False
241
242 try:
243 p = int(protocol)
244 except TypeError:
245 p = DEFAULT_RESP_VERSION
246 except ValueError:
247 raise ConnectionError("protocol must be an integer")
248 finally:
249 if p < 2 or p > 3:
250 raise ConnectionError("protocol must be either 2 or 3")
251 self.protocol = protocol
252
    def __del__(self, _warnings: Any = warnings):
        # `_warnings` is bound as a default argument so the module object is
        # still reachable during interpreter shutdown, when module globals
        # may already have been cleared.
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

        try:
            # Closing the transport needs a running event loop; if there is
            # none, get_running_loop() raises RuntimeError and we do nothing.
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No actions been taken if pool already closed.
            pass
268
269 def _close(self):
270 """
271 Internal method to silently close the connection without waiting
272 """
273 if self._writer:
274 self._writer.close()
275 self._writer = self._reader = None
276
277 def __repr__(self):
278 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
279 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
280
    @abstractmethod
    def repr_pieces(self):
        # Subclasses return an iterable of (name, value) pairs for __repr__.
        pass
284
285 @property
286 def is_connected(self):
287 return self._reader is not None and self._writer is not None
288
289 def register_connect_callback(self, callback):
290 """
291 Register a callback to be called when the connection is established either
292 initially or reconnected. This allows listeners to issue commands that
293 are ephemeral to the connection, for example pub/sub subscription or
294 key tracking. The callback must be a _method_ and will be kept as
295 a weak reference.
296 """
297 wm = weakref.WeakMethod(callback)
298 if wm not in self._connect_callbacks:
299 self._connect_callbacks.append(wm)
300
301 def deregister_connect_callback(self, callback):
302 """
303 De-register a previously registered callback. It will no-longer receive
304 notifications on connection events. Calling this is not required when the
305 listener goes away, since the callbacks are kept as weak methods.
306 """
307 try:
308 self._connect_callbacks.remove(weakref.WeakMethod(callback))
309 except ValueError:
310 pass
311
312 def set_parser(self, parser_class: Type[BaseParser]) -> None:
313 """
314 Creates a new instance of parser_class with socket size:
315 _socket_read_size and assigns it to the parser for the connection
316 :param parser_class: The required parser class
317 """
318 self._parser = parser_class(socket_read_size=self._socket_read_size)
319
320 async def connect(self):
321 """Connects to the Redis server if not already connected"""
322 # try once the socket connect with the handshake, retry the whole
323 # connect/handshake flow based on retry policy
324 await self.retry.call_with_retry(
325 lambda: self.connect_check_health(
326 check_health=True, retry_socket_connect=False
327 ),
328 lambda error: self.disconnect(),
329 )
330
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Establish the transport (if needed), then run the connect handshake
        and any registered connect callbacks.

        :param check_health: forwarded to on_connect_check_health; controls
            whether the post-handshake commands run the PING health check
        :param retry_socket_connect: when True, the raw socket connect is
            retried per the retry policy; connect() passes False because it
            already retries the whole connect+handshake flow itself
        :raises TimeoutError: the socket connect timed out
        :raises ConnectionError: connect or handshake failed
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            # wrap any other failure in ConnectionError, preserving the cause
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
377
    def mark_for_reconnect(self):
        # Flag this connection so the owner reconnects it before next use.
        self._should_reconnect = True

    def should_reconnect(self):
        # True once mark_for_reconnect() has been called on this connection.
        return self._should_reconnect
383
    @abstractmethod
    async def _connect(self):
        # Subclasses establish the transport (TCP, TLS or UDS) and assign
        # self._reader / self._writer.
        pass

    @abstractmethod
    def _host_error(self) -> str:
        # Subclasses return a human-readable endpoint description used in
        # error messages (e.g. "host:port" or a socket path).
        pass

    def _error_message(self, exception: BaseException) -> str:
        # Combine the endpoint description with the OS error details.
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        # RESP protocol version as originally supplied (may be int or str).
        return self.protocol
397
    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        # Delegates to on_connect_check_health with health checking enabled.
        await self.on_connect_check_health(check_health=True)
401
    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """
        Post-connect handshake: authenticate, negotiate the RESP protocol
        version, set the client name and library info, and select the
        configured database.

        :param check_health: forwarded to the later, non-AUTH commands;
            AUTH/HELLO always pass check_health=False because a PING health
            check would fail before authentication.
        :raises AuthenticationError: credentials were rejected
        :raises ConnectionError: RESP negotiation, SETNAME or SELECT failed
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                # switch to the RESP3 parser before reading the HELLO reply
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as the default user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # the reply may carry bytes or str keys depending on decoding
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort: older servers reject it
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
516
    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server

        :param nowait: when True, do not wait for the transport's close
            handshake to complete (used when tearing down after errors)
        :raises TimeoutError: closing exceeded ``socket_connect_timeout``
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    # always drop the stream references so is_connected
                    # reports False even if close() raised
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
539
540 async def _send_ping(self):
541 """Send PING, expect PONG in return"""
542 await self.send_command("PING", check_health=False)
543 if str_if_bytes(await self.read_response()) != "PONG":
544 raise ConnectionError("Bad response from PING health check")
545
546 async def _ping_failed(self, error):
547 """Function to call when PING fails"""
548 await self.disconnect()
549
550 async def check_health(self):
551 """Check the health of the connection with a PING/PONG"""
552 if (
553 self.health_check_interval
554 and asyncio.get_running_loop().time() > self.next_health_check
555 ):
556 await self.retry.call_with_retry(self._send_ping, self._ping_failed)
557
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Write all packed chunks at once and flush. Kept as a separate
        # coroutine so send_packed_command() can wrap it in asyncio.wait_for
        # when a socket timeout is configured.
        self._writer.writelines(command)
        await self._writer.drain()
561
562 async def send_packed_command(
563 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
564 ) -> None:
565 if not self.is_connected:
566 await self.connect_check_health(check_health=False)
567 if check_health:
568 await self.check_health()
569
570 try:
571 if isinstance(command, str):
572 command = command.encode()
573 if isinstance(command, bytes):
574 command = [command]
575 if self.socket_timeout:
576 await asyncio.wait_for(
577 self._send_packed_command(command), self.socket_timeout
578 )
579 else:
580 self._writer.writelines(command)
581 await self._writer.drain()
582 except asyncio.TimeoutError:
583 await self.disconnect(nowait=True)
584 raise TimeoutError("Timeout writing to socket") from None
585 except OSError as e:
586 await self.disconnect(nowait=True)
587 if len(e.args) == 1:
588 err_no, errmsg = "UNKNOWN", e.args[0]
589 else:
590 err_no = e.args[0]
591 errmsg = e.args[1]
592 raise ConnectionError(
593 f"Error {err_no} while writing to socket. {errmsg}."
594 ) from e
595 except BaseException:
596 # BaseExceptions can be raised when a socket send operation is not
597 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
598 # to send un-sent data. However, the send_packed_command() API
599 # does not support it so there is no point in keeping the connection open.
600 await self.disconnect(nowait=True)
601 raise
602
603 async def send_command(self, *args: Any, **kwargs: Any) -> None:
604 """Pack and send a command to the Redis server"""
605 await self.send_packed_command(
606 self.pack_command(*args), check_health=kwargs.get("check_health", True)
607 )
608
609 async def can_read_destructive(self):
610 """Poll the socket to see if there's data that can be read."""
611 try:
612 return await self._parser.can_read_destructive()
613 except OSError as e:
614 await self.disconnect(nowait=True)
615 host_error = self._host_error()
616 raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
617
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param disable_decoding: skip response decoding in the parser
        :param timeout: per-call read timeout overriding ``socket_timeout``;
            when a caller-supplied timeout expires, returns None instead of
            raising so the operation can be retried
        :param disconnect_on_error: drop the connection on read errors
        :param push_request: forwarded to the RESP3 parser for push messages
        :raises TimeoutError: ``socket_timeout`` expired (no explicit timeout)
        :raises ConnectionError: the socket failed while reading
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # four branches (timeout x protocol) because only the RESP3
            # parser accepts the push_request keyword
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        # a successful read postpones the next scheduled health check
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        # some parsers return ResponseError instances instead of raising them
        if isinstance(response, ResponseError):
            raise response from None
        return response
675
    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # RESP array header: *<number of elements>\r\n
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                # flush the pending buffer (plus this arg's bulk-string
                # header), then emit the large arg as its own chunk so it
                # is never copied into a joined buffer
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                # small arg: append its full bulk-string encoding in place
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
721
722 def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
723 """Pack multiple commands into the Redis protocol"""
724 output: List[bytes] = []
725 pieces: List[bytes] = []
726 buffer_length = 0
727 buffer_cutoff = self._buffer_cutoff
728
729 for cmd in commands:
730 for chunk in self.pack_command(*cmd):
731 chunklen = len(chunk)
732 if (
733 buffer_length > buffer_cutoff
734 or chunklen > buffer_cutoff
735 or isinstance(chunk, memoryview)
736 ):
737 if pieces:
738 output.append(SYM_EMPTY.join(pieces))
739 buffer_length = 0
740 pieces = []
741
742 if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
743 output.append(chunk)
744 else:
745 pieces.append(chunk)
746 buffer_length += chunklen
747
748 if pieces:
749 output.append(SYM_EMPTY.join(pieces))
750 return output
751
752 def _socket_is_empty(self):
753 """Check if the socket is empty"""
754 return len(self._reader._buffer) == 0
755
    async def process_invalidation_messages(self):
        # Drain any push/invalidation messages already buffered on the
        # socket, without blocking to wait for more data.
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
759
    def set_re_auth_token(self, token: TokenInterface):
        # Stash a token for re-authentication on the next re_auth() call.
        self._re_auth_token = token
762
763 async def re_auth(self):
764 if self._re_auth_token is not None:
765 await self.send_command(
766 "AUTH",
767 self._re_auth_token.try_get("oid"),
768 self._re_auth_token.get_value(),
769 )
770 await self.read_response()
771 self._re_auth_token = None
772
773
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """
        :param host: server hostname or IP address
        :param port: server port; accepts str or int
        :param socket_keepalive: enable SO_KEEPALIVE on the TCP socket
        :param socket_keepalive_options: extra keepalive socket options
            ({option: value}) applied when ``socket_keepalive`` is enabled
        :param socket_type: retained for API compatibility; not passed to
            asyncio.open_connection
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # overridden by SSLConnection to add the SSL context
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # disable Nagle's algorithm; commands are written in one batch
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                # also reset the stream references so is_connected reports
                # False and a subsequent connect() can retry cleanly
                self._reader = None
                self._writer = None
                raise

    def _host_error(self) -> str:
        # endpoint description used in error messages
        return f"{self.host}:{self.port}"
829
830
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # NOTE: the `ssl.VerifyMode` annotation is quoted because the `ssl`
        # name is None when Python is built without SSL support (see the
        # SSL_AVAILABLE guard at module top); an unquoted annotation would
        # be evaluated at definition time and crash the import.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """
        :param ssl_keyfile: path to the client's SSL private key
        :param ssl_certfile: path to the client's SSL certificate
        :param ssl_cert_reqs: verification mode — "none"/"optional"/"required"
            or an ssl.VerifyMode value
        :param ssl_include_verify_flags: verify flags OR'ed into the context
        :param ssl_exclude_verify_flags: verify flags masked out of the context
        :param ssl_ca_certs: path to a PEM file of concatenated CA certificates
        :param ssl_ca_data: PEM string or DER bytes of CA certificates
        :param ssl_ca_path: path to a directory of CA certificates
        :param ssl_check_hostname: match the server hostname during handshake
        :param ssl_min_version: lowest TLS version to accept
        :param ssl_ciphers: allowed cipher list string
        :param ssl_password: password for decrypting the private key
        :raises RedisError: Python was built without SSL support
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        # All SSL settings are delegated to a lazily-built RedisSSLContext.
        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # extend the base TCP arguments with the (cached) SSLContext
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the underlying RedisSSLContext configuration.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
912
913
class RedisSSLContext:
    """Normalizes Redis SSL settings and lazily builds an ssl.SSLContext."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # NOTE: `ssl.VerifyMode` is quoted because `ssl` is None when Python
        # lacks SSL support; an unquoted annotation would be evaluated at
        # definition time and crash the import.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Store and normalize SSL settings; the context is built on first get().

        :raises RedisError: Python was built without SSL support, or
            ``cert_reqs`` is an unrecognized string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        # normalize cert_reqs: None -> CERT_NONE, str -> ssl constant
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # hostname checking is invalid without certificate verification
        # (ssl rejects check_hostname=True with CERT_NONE)
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the configured SSLContext, building and caching it once."""
        if not self.context:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
1005
1006
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """
        :param path: filesystem path of the Redis unix domain socket
        """
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Create the unix-domain-socket transport.

        The post-connect handshake (AUTH/HELLO/SELECT) is performed by
        connect_check_health() right after _connect() returns, exactly as
        for the TCP Connection class — do NOT call on_connect() here, or
        the whole handshake runs twice on every connect.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        # endpoint description used in error messages
        return self.path
1029
1030
# Querystring values (compared after upper-casing) that to_bool() maps to False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
1032
1033
def to_bool(value) -> Optional[bool]:
    """Coerce a URL querystring value to a bool.

    Returns None for missing/empty values, False for recognized "false"
    strings (see FALSE_STRINGS), otherwise the ordinary truthiness of value.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
1040
1041
1042def parse_ssl_verify_flags(value):
1043 # flags are passed in as a string representation of a list,
1044 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1045 verify_flags_str = value.replace("[", "").replace("]", "")
1046
1047 verify_flags = []
1048 for flag in verify_flags_str.split(","):
1049 flag = flag.strip()
1050 if not hasattr(VerifyFlags, flag):
1051 raise ValueError(f"Invalid ssl verify flag: {flag}")
1052 verify_flags.append(getattr(VerifyFlags, flag))
1053 return verify_flags
1054
1055
# Maps recognized URL querystring option names to the callable used to coerce
# the raw (already unquoted) string value to its Python type. Options not in
# this table are passed through as plain strings by parse_url(). Wrapped in
# MappingProxyType so the table is read-only at runtime.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1071
1072
class ConnectKwargs(TypedDict, total=False):
    # Keyword arguments extracted from a Redis URL by parse_url(). Every key
    # is optional (total=False) because each appears only when the URL
    # actually carries that piece of information.
    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
1081
1082
def parse_url(url: str) -> "ConnectKwargs":
    """Parse a Redis URL into keyword arguments for a connection (pool).

    Supported schemes are ``redis://``, ``rediss://`` and ``unix://``.
    Username, password, hostname, path and querystring values are
    percent-decoded via ``urllib.parse.unquote``. Querystring options listed
    in ``URL_QUERY_ARGUMENT_PARSERS`` are coerced to their Python types;
    unrecognized options are passed through as strings.

    Raises:
        ValueError: if a recognized querystring value cannot be coerced, or
            if the URL scheme is not one of the three supported schemes.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs only yields non-empty lists, so a truthiness test is
        # sufficient (the previous len() > 0 check was redundant).
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # Non-numeric path: silently ignore, leaving the default db.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1134
1135
# Type variable bound to ConnectionPool so from_url() returns the caller's
# (sub)class type rather than the base class.
_CP = TypeVar("_CP", bound="ConnectionPool")
1137
1138
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override any conflicting keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None means "effectively unlimited" (2**31 connections).
        max_connections = max_connections or 2**31
        # NOTE(review): 0 passes this check even though the message says
        # "positive" -- preserved as-is for backwards compatibility.
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections available for checkout.
        self._available_connections: List[AbstractConnection] = []
        # Connections currently checked out by callers.
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Guards pool bookkeeping only; connection I/O is performed outside
        # this lock (see get_connection()).
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them.

        NOTE(review): in-use connections become a ``weakref.WeakSet`` here
        while ``__init__`` uses a plain ``set`` -- presumably so connections
        checked out before the reset can be garbage collected; confirm
        against callers before changing.
        """
        self._available_connections = []
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() coercion fixes the declared return type: the bare expression
        # previously returned the internal list itself when non-empty.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        async with self._lock:
            connection = self.get_available_connection()

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            # On any failure (including cancellation) return the connection
            # so it is not leaked from the pool, then re-raise.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            # No idle connection: create one, unless the pool is exhausted.
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            # A freshly (re)connected connection must be quiet; anything
            # readable here means the connection is unusable.
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            # Drop the transport now; the connection will reconnect lazily
            # the next time it is used.
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        # Disconnect concurrently; collect exceptions instead of failing fast
        # so every connection gets a disconnect attempt.
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        # Surface the first failure, if any, after all disconnects ran.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Apply a retry policy to every connection, idle or in use."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a fresh auth token.

        Idle connections are re-AUTHed immediately; in-use connections are
        only flagged with the new token.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas close over ``conn`` but are awaited within the
                # same iteration, so late binding is not an issue here.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1375
1376
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        # ``queue_class`` is accepted for backwards compatibility but is not
        # used by this implementation.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Condition used to block waiters until a connection is released.
        self._condition = asyncio.Condition()
        # Seconds to wait for a connection; None blocks forever.
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        try:
            async with self._condition:
                # The timeout covers both waiting for capacity and acquiring
                # the connection slot; wait_for() releases the condition lock
                # while waiting, so releases can proceed.
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            # Return the connection to the pool on any failure, then re-raise.
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake a single waiter blocked in get_connection(), if any.
            self._condition.notify()