1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import warnings
8import weakref
9from abc import abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..utils import SSL_AVAILABLE
30
31if SSL_AVAILABLE:
32 import ssl
33 from ssl import SSLContext, TLSVersion, VerifyFlags
34else:
35 ssl = None
36 TLSVersion = None
37 SSLContext = None
38 VerifyFlags = None
39
40from ..auth.token import TokenInterface
41from ..driver_info import DriverInfo, resolve_driver_info
42from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
43from ..utils import deprecated_args, format_error_message
44
45# the functionality is available in 3.11.x but has a major issue before
46# 3.11.3. See https://github.com/redis/redis-py/issues/2633
47if sys.version_info >= (3, 11, 3):
48 from asyncio import timeout as async_timeout
49else:
50 from async_timeout import timeout as async_timeout
51
52from redis.asyncio.retry import Retry
53from redis.backoff import NoBackoff
54from redis.connection import DEFAULT_RESP_VERSION
55from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
56from redis.exceptions import (
57 AuthenticationError,
58 AuthenticationWrongNumberOfArgsError,
59 ConnectionError,
60 DataError,
61 MaxConnectionsError,
62 RedisError,
63 ResponseError,
64 TimeoutError,
65)
66from redis.typing import EncodableT
67from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
68
69from .._parsers import (
70 BaseParser,
71 Encoder,
72 _AsyncHiredisParser,
73 _AsyncRESP2Parser,
74 _AsyncRESP3Parser,
75)
76
77SYM_STAR = b"*"
78SYM_DOLLAR = b"$"
79SYM_CRLF = b"\r\n"
80SYM_LF = b"\n"
81SYM_EMPTY = b""
82
83
class _Sentinel(enum.Enum):
    # Enum-based sentinel: unlike a bare object(), an enum member keeps its
    # identity through copy.deepcopy and pickling, so `is` checks stay valid.
    sentinel = object()


# Marker distinguishing "argument not supplied" from an explicit None/[].
SENTINEL = _Sentinel.sentinel
89
90
# Pick the fastest available RESP parser: the hiredis C-extension parser when
# installed, otherwise the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
96
97
class ConnectCallbackProtocol(Protocol):
    """Structural type of a synchronous on-connect callback."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type of an awaitable on-connect callback."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Connect callbacks may be either sync or async callables.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
107
108
109class AbstractConnection:
110 """Manages communication to and from a Redis server"""
111
112 __slots__ = (
113 "db",
114 "username",
115 "client_name",
116 "lib_name",
117 "lib_version",
118 "credential_provider",
119 "password",
120 "socket_timeout",
121 "socket_connect_timeout",
122 "redis_connect_func",
123 "retry_on_timeout",
124 "retry_on_error",
125 "health_check_interval",
126 "next_health_check",
127 "last_active_at",
128 "encoder",
129 "ssl_context",
130 "protocol",
131 "_reader",
132 "_writer",
133 "_parser",
134 "_connect_callbacks",
135 "_buffer_cutoff",
136 "_lock",
137 "_socket_read_size",
138 "__dict__",
139 )
140
141 @deprecated_args(
142 args_to_warn=["lib_name", "lib_version"],
143 reason="Use 'driver_info' parameter instead. "
144 "lib_name and lib_version will be removed in a future version.",
145 )
146 def __init__(
147 self,
148 *,
149 db: Union[str, int] = 0,
150 password: Optional[str] = None,
151 socket_timeout: Optional[float] = None,
152 socket_connect_timeout: Optional[float] = None,
153 retry_on_timeout: bool = False,
154 retry_on_error: Union[list, _Sentinel] = SENTINEL,
155 encoding: str = "utf-8",
156 encoding_errors: str = "strict",
157 decode_responses: bool = False,
158 parser_class: Type[BaseParser] = DefaultParser,
159 socket_read_size: int = 65536,
160 health_check_interval: float = 0,
161 client_name: Optional[str] = None,
162 lib_name: Optional[str] = None,
163 lib_version: Optional[str] = None,
164 driver_info: Optional[DriverInfo] = None,
165 username: Optional[str] = None,
166 retry: Optional[Retry] = None,
167 redis_connect_func: Optional[ConnectCallbackT] = None,
168 encoder_class: Type[Encoder] = Encoder,
169 credential_provider: Optional[CredentialProvider] = None,
170 protocol: Optional[int] = 2,
171 event_dispatcher: Optional[EventDispatcher] = None,
172 ):
173 """
174 Initialize a new async Connection.
175
176 Parameters
177 ----------
178 driver_info : DriverInfo, optional
179 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
180 are ignored. If not provided, a DriverInfo will be created from lib_name
181 and lib_version (or defaults if those are also None).
182 lib_name : str, optional
183 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
184 lib_version : str, optional
185 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
186 """
187 if (username or password) and credential_provider is not None:
188 raise DataError(
189 "'username' and 'password' cannot be passed along with 'credential_"
190 "provider'. Please provide only one of the following arguments: \n"
191 "1. 'password' and (optional) 'username'\n"
192 "2. 'credential_provider'"
193 )
194 if event_dispatcher is None:
195 self._event_dispatcher = EventDispatcher()
196 else:
197 self._event_dispatcher = event_dispatcher
198 self.db = db
199 self.client_name = client_name
200
201 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
202 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
203
204 self.credential_provider = credential_provider
205 self.password = password
206 self.username = username
207 self.socket_timeout = socket_timeout
208 if socket_connect_timeout is None:
209 socket_connect_timeout = socket_timeout
210 self.socket_connect_timeout = socket_connect_timeout
211 self.retry_on_timeout = retry_on_timeout
212 if retry_on_error is SENTINEL:
213 retry_on_error = []
214 if retry_on_timeout:
215 retry_on_error.append(TimeoutError)
216 retry_on_error.append(socket.timeout)
217 retry_on_error.append(asyncio.TimeoutError)
218 self.retry_on_error = retry_on_error
219 if retry or retry_on_error:
220 if not retry:
221 self.retry = Retry(NoBackoff(), 1)
222 else:
223 # deep-copy the Retry object as it is mutable
224 self.retry = copy.deepcopy(retry)
225 # Update the retry's supported errors with the specified errors
226 self.retry.update_supported_errors(retry_on_error)
227 else:
228 self.retry = Retry(NoBackoff(), 0)
229 self.health_check_interval = health_check_interval
230 self.next_health_check: float = -1
231 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
232 self.redis_connect_func = redis_connect_func
233 self._reader: Optional[asyncio.StreamReader] = None
234 self._writer: Optional[asyncio.StreamWriter] = None
235 self._socket_read_size = socket_read_size
236 self.set_parser(parser_class)
237 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
238 self._buffer_cutoff = 6000
239 self._re_auth_token: Optional[TokenInterface] = None
240 self._should_reconnect = False
241
242 try:
243 p = int(protocol)
244 except TypeError:
245 p = DEFAULT_RESP_VERSION
246 except ValueError:
247 raise ConnectionError("protocol must be an integer")
248 finally:
249 if p < 2 or p > 3:
250 raise ConnectionError("protocol must be either 2 or 3")
251 self.protocol = protocol
252
    def __del__(self, _warnings: Any = warnings):
        """Finalizer: warn about (and best-effort close) an unclosed connection.

        ``_warnings`` is bound as a default argument so the module object is
        still reachable during interpreter shutdown.
        """
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                # _close() is only safe while a loop is still running;
                # get_running_loop() raises RuntimeError otherwise.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass
268
269 def _close(self):
270 """
271 Internal method to silently close the connection without waiting
272 """
273 if self._writer:
274 self._writer.close()
275 self._writer = self._reader = None
276
277 def __repr__(self):
278 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
279 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
280
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for __repr__."""
        pass
284
285 @property
286 def is_connected(self):
287 return self._reader is not None and self._writer is not None
288
289 def register_connect_callback(self, callback):
290 """
291 Register a callback to be called when the connection is established either
292 initially or reconnected. This allows listeners to issue commands that
293 are ephemeral to the connection, for example pub/sub subscription or
294 key tracking. The callback must be a _method_ and will be kept as
295 a weak reference.
296 """
297 wm = weakref.WeakMethod(callback)
298 if wm not in self._connect_callbacks:
299 self._connect_callbacks.append(wm)
300
301 def deregister_connect_callback(self, callback):
302 """
303 De-register a previously registered callback. It will no-longer receive
304 notifications on connection events. Calling this is not required when the
305 listener goes away, since the callbacks are kept as weak methods.
306 """
307 try:
308 self._connect_callbacks.remove(weakref.WeakMethod(callback))
309 except ValueError:
310 pass
311
312 def set_parser(self, parser_class: Type[BaseParser]) -> None:
313 """
314 Creates a new instance of parser_class with socket size:
315 _socket_read_size and assigns it to the parser for the connection
316 :param parser_class: The required parser class
317 """
318 self._parser = parser_class(socket_read_size=self._socket_read_size)
319
320 async def connect(self):
321 """Connects to the Redis server if not already connected"""
322 # try once the socket connect with the handshake, retry the whole
323 # connect/handshake flow based on retry policy
324 await self.retry.call_with_retry(
325 lambda: self.connect_check_health(
326 check_health=True, retry_socket_connect=False
327 ),
328 lambda error: self.disconnect(),
329 )
330
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Connect and run the post-connect handshake if not already connected.

        ``retry_socket_connect`` controls whether the raw socket connect is
        retried here; connect() passes False because it retries the whole
        connect+handshake flow itself.
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            # Any other failure is surfaced as a ConnectionError, chained to
            # the original cause.
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
377
    def mark_for_reconnect(self):
        # Flag this connection to be torn down and re-established later.
        self._should_reconnect = True

    def should_reconnect(self):
        # True when a reconnect was requested via mark_for_reconnect().
        return self._should_reconnect

    def reset_should_reconnect(self):
        # Clear the reconnect-request flag.
        self._should_reconnect = False
386
    @abstractmethod
    async def _connect(self):
        """Establish the underlying transport and set ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint description for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        # Format "<endpoint>: <exception details>" via the shared helper.
        return format_error_message(self._host_error(), exception)
397
    def get_protocol(self):
        # RESP protocol version exactly as supplied by the caller
        # (may be an int or a str).
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)
404
    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """
        Run the post-connect handshake: authenticate (AUTH or HELLO),
        switch RESP protocol if requested, set the client name, report
        library name/version via CLIENT SETINFO, and SELECT the database.

        Some commands are pipelined (sent before their replies are read)
        to reduce connection-setup round trips.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                # Upgrade to the RESP3 parser before the HELLO reply arrives.
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # Password-only credentials: HELLO AUTH needs a username too.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The reply map may use bytes or str keys depending on decoding.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO may be unsupported by older servers; ignore.
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
519
    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server.

        When ``nowait`` is True, does not await the transport's close
        completion (used on error paths / forceful teardown).
        """
        try:
            # Bound the whole teardown with the connect timeout so a hung
            # close cannot block forever.
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    # Always drop the streams, even if close() failed.
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
544
545 async def _send_ping(self):
546 """Send PING, expect PONG in return"""
547 await self.send_command("PING", check_health=False)
548 if str_if_bytes(await self.read_response()) != "PONG":
549 raise ConnectionError("Bad response from PING health check")
550
    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        # The connection is unusable; drop it so the next command reconnects.
        await self.disconnect()
554
555 async def check_health(self):
556 """Check the health of the connection with a PING/PONG"""
557 if (
558 self.health_check_interval
559 and asyncio.get_running_loop().time() > self.next_health_check
560 ):
561 await self.retry.call_with_retry(self._send_ping, self._ping_failed)
562
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Low-level write helper: queue all chunks, then flush the transport.
        self._writer.writelines(command)
        await self._writer.drain()
566
    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, reconnecting first if needed.

        On any write failure the connection is dropped before the error is
        re-raised, so the next command starts from a clean reconnect.
        """
        if not self.is_connected:
            # Skip the health check: we are about to write on a brand-new
            # connection anyway.
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            # Normalize str/bytes input into a list of byte chunks.
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            # OSError may carry (errno, message) or just (message,).
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise
607
608 async def send_command(self, *args: Any, **kwargs: Any) -> None:
609 """Pack and send a command to the Redis server"""
610 await self.send_packed_command(
611 self.pack_command(*args), check_health=kwargs.get("check_health", True)
612 )
613
    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            # Socket failure while polling: drop the connection and surface
            # a ConnectionError naming the endpoint.
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
622
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        An explicit ``timeout`` overrides the connection's socket_timeout,
        and its expiry returns None (retryable) instead of raising.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # Four paths: with/without a read timeout, crossed with RESP3 vs
            # RESP2 (only the RESP3 parser accepts push_request).
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        # A successful read is proof of liveness: push out the next
        # scheduled health check.
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        # The parser returns server-side errors as exception objects;
        # raise them here so callers see them as exceptions.
        if isinstance(response, ResponseError):
            raise response from None
        return response
680
681 def pack_command(self, *args: EncodableT) -> List[bytes]:
682 """Pack a series of arguments into the Redis protocol"""
683 output = []
684 # the client might have included 1 or more literal arguments in
685 # the command name, e.g., 'CONFIG GET'. The Redis server expects these
686 # arguments to be sent separately, so split the first argument
687 # manually. These arguments should be bytestrings so that they are
688 # not encoded.
689 assert not isinstance(args[0], float)
690 if isinstance(args[0], str):
691 args = tuple(args[0].encode().split()) + args[1:]
692 elif b" " in args[0]:
693 args = tuple(args[0].split()) + args[1:]
694
695 buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
696
697 buffer_cutoff = self._buffer_cutoff
698 for arg in map(self.encoder.encode, args):
699 # to avoid large string mallocs, chunk the command into the
700 # output list if we're sending large values or memoryviews
701 arg_length = len(arg)
702 if (
703 len(buff) > buffer_cutoff
704 or arg_length > buffer_cutoff
705 or isinstance(arg, memoryview)
706 ):
707 buff = SYM_EMPTY.join(
708 (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
709 )
710 output.append(buff)
711 output.append(arg)
712 buff = SYM_CRLF
713 else:
714 buff = SYM_EMPTY.join(
715 (
716 buff,
717 SYM_DOLLAR,
718 str(arg_length).encode(),
719 SYM_CRLF,
720 arg,
721 SYM_CRLF,
722 )
723 )
724 output.append(buff)
725 return output
726
    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []  # chunks ready to be written
        pieces: List[bytes] = []  # small chunks being coalesced into one buffer
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # Flush the coalescing buffer before handling this chunk.
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    # Large values and memoryviews are emitted as-is to avoid
                    # copying them into a joined buffer.
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        # Flush whatever small pieces remain.
        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
756
    def _socket_is_empty(self):
        """Check if the socket is empty.

        NOTE(review): reaches into StreamReader._buffer, a private asyncio
        attribute -- may break if asyncio internals change; confirm there is
        no public alternative.
        """
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        # Drain any push (invalidation) messages already buffered on the
        # socket, handing them to the parser as push requests.
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
764
    def set_re_auth_token(self, token: TokenInterface):
        # Stash a token to re-authenticate with on the next re_auth() call.
        self._re_auth_token = token

    async def re_auth(self):
        # Re-authenticate with the stored token, then clear it so the
        # re-auth happens at most once per token.
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
777
778
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        # TCP-specific settings; everything else is handled by the base class.
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs shown in __repr__.
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # Keyword arguments for asyncio.open_connection(); SSLConnection
        # extends this with an SSL context.
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle's algorithm: Redis traffic is many small
            # request/reply messages where latency matters.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # Endpoint description used in error messages.
        return f"{self.host}:{self.port}"
834
835
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """All ``ssl_*`` options are handed to a RedisSSLContext; remaining
        kwargs go to ``Connection``.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        # The actual ssl.SSLContext is built lazily in RedisSSLContext.get().
        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # Add the SSL context so asyncio.open_connection negotiates TLS.
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the underlying RedisSSLContext configuration.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
917
918
class RedisSSLContext:
    """Lazily builds and caches an ssl.SSLContext from redis-py's ``ssl_*``
    configuration options.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        # cert_reqs may be a string alias or an ssl.VerifyMode value.
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is only meaningful when certificates are verified;
        # ssl forbids check_hostname=True with CERT_NONE.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        # Built on first get() and reused afterwards.
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached SSLContext, building it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            # Set check_hostname before verify_mode: ssl raises ValueError
            # when verify_mode becomes CERT_NONE while check_hostname is True.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
1010
1011
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, *, path: str = "", **kwargs):
        # Filesystem path of the Unix domain socket to connect to.
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return the (name, value) pairs rendered in the connection repr."""
        pieces: List[Tuple[str, Union[str, int]]] = [
            ("path", self.path),
            ("db", self.db),
        ]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the UDS stream (bounded by the connect timeout) and run the
        post-connect handshake."""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        await self.on_connect()

    def _host_error(self) -> str:
        # For error messages, the socket path plays the role of the host.
        return self.path
1034
1035
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a querystring value as a boolean.

    ``None`` and the empty string map to ``None`` (treated as "unset").
    Strings matching one of ``FALSE_STRINGS`` (case-insensitive) map to
    ``False``; any other value is coerced with ``bool()``.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        # Any non-empty string is truthy unless it spells a "false" word.
        return value.upper() not in FALSE_STRINGS
    return bool(value)
1045
1046
1047def parse_ssl_verify_flags(value):
1048 # flags are passed in as a string representation of a list,
1049 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1050 verify_flags_str = value.replace("[", "").replace("]", "")
1051
1052 verify_flags = []
1053 for flag in verify_flags_str.split(","):
1054 flag = flag.strip()
1055 if not hasattr(VerifyFlags, flag):
1056 raise ValueError(f"Invalid ssl verify flag: {flag}")
1057 verify_flags.append(getattr(VerifyFlags, flag))
1058 return verify_flags
1059
1060
# Casts applied by parse_url() to known querystring options. Options not
# listed here are passed through to the connection as plain (unquoted)
# strings. Wrapped in MappingProxyType so the table is read-only.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1076
1077
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments produced by :func:`parse_url`.

    ``total=False``: every key is optional — only the options actually
    present in the URL are populated.
    """

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str  # Unix domain socket path (unix:// scheme only)
1086
1087
def parse_url(url: str) -> ConnectKwargs:
    """Parse a Redis URL into connection keyword arguments.

    Supports the ``redis://``, ``rediss://`` and ``unix://`` schemes.
    Known querystring options are decoded and cast via
    ``URL_QUERY_ARGUMENT_PARSERS``; unknown options are passed through as
    plain strings.

    Raises ``ValueError`` for an unsupported scheme or for a querystring
    value that fails to parse.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs never yields empty lists, but guard anyway; only the
        # first occurrence of a repeated option is honored.
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError) as err:
                    # Chain the original cause explicitly for debuggability.
                    raise ValueError(
                        f"Invalid value for '{name}' in connection URL."
                    ) from err
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # A non-numeric path is silently ignored; db stays unset.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1139
1140
# TypeVar bound to ConnectionPool so that ``from_url`` invoked on a
# subclass is typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
1142
1143
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (or 0) means "effectively unbounded": fall back to a large cap.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Guards pool bookkeeping only; connection I/O happens outside it.
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them."""
        self._available_connections = []
        # Use a plain set for consistency with __init__ (it previously used
        # a WeakSet here, which would let checked-out connections silently
        # drop out of the pool's accounting once no strong reference to
        # them remained).
        self._in_use_connections = set()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() so the annotated return type holds -- the first operand is
        # a list, and ``or`` would otherwise return the list itself.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        async with self._lock:
            connection = self.get_available_connection()

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            # Return the connection on *any* failure (including
            # cancellation) so the pool's accounting stays correct.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            # Drop the socket now so the next checkout reconnects cleanly.
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        # Disconnect concurrently; collect exceptions so every connection
        # gets a chance to close before the first error is re-raised.
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Apply a retry policy to every connection the pool knows about."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-AUTH idle connections with a fresh token; mark in-use
        connections to re-authenticate when they are next safe to do so."""
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass
1380
1381
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

    >>> # Block forever.
    >>> pool = BlockingConnectionPool(timeout=None)

    >>> # Raise a ``ConnectionError`` after five seconds if a connection is
    >>> # not available.
    >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Waiters block on this condition until a connection is released.
        self._condition = asyncio.Condition()
        # Seconds to wait for a free connection; None means wait forever.
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        try:
            async with self._condition:
                # Only the wait itself is bounded by the timeout; acquiring
                # the condition's lock is not.
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    conn = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(conn)
        except BaseException:
            # Hand the connection back (waking a waiter) before re-raising.
            await self.release(conn)
            raise
        return conn

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one Task blocked in get_connection(), if any.
            self._condition.notify()