1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import time
8import warnings
9import weakref
10from abc import abstractmethod
11from itertools import chain
12from types import MappingProxyType
13from typing import (
14 Any,
15 Callable,
16 Iterable,
17 List,
18 Mapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27)
28from urllib.parse import ParseResult, parse_qs, unquote, urlparse
29
30from ..observability.attributes import (
31 DB_CLIENT_CONNECTION_POOL_NAME,
32 DB_CLIENT_CONNECTION_STATE,
33 AttributeBuilder,
34 ConnectionState,
35 get_pool_name,
36)
37from ..utils import SSL_AVAILABLE
38
39if SSL_AVAILABLE:
40 import ssl
41 from ssl import SSLContext, TLSVersion, VerifyFlags
42else:
43 ssl = None
44 TLSVersion = None
45 SSLContext = None
46 VerifyFlags = None
47
48from ..auth.token import TokenInterface
49from ..driver_info import DriverInfo, resolve_driver_info
50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
51from ..utils import deprecated_args, format_error_message
52
53# the functionality is available in 3.11.x but has a major issue before
54# 3.11.3. See https://github.com/redis/redis-py/issues/2633
55if sys.version_info >= (3, 11, 3):
56 from asyncio import timeout as async_timeout
57else:
58 from async_timeout import timeout as async_timeout
59
60from redis.asyncio.observability.recorder import (
61 record_connection_closed,
62 record_connection_count,
63 record_connection_create_time,
64 record_connection_wait_time,
65 record_error_count,
66)
67from redis.asyncio.retry import Retry
68from redis.backoff import NoBackoff
69from redis.connection import DEFAULT_RESP_VERSION
70from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
71from redis.exceptions import (
72 AuthenticationError,
73 AuthenticationWrongNumberOfArgsError,
74 ConnectionError,
75 DataError,
76 MaxConnectionsError,
77 RedisError,
78 ResponseError,
79 TimeoutError,
80)
81from redis.observability.metrics import CloseReason
82from redis.typing import EncodableT
83from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
84
85from .._parsers import (
86 BaseParser,
87 Encoder,
88 _AsyncHiredisParser,
89 _AsyncRESP2Parser,
90 _AsyncRESP3Parser,
91)
92
# RESP wire-protocol framing bytes used when packing commands.
SYM_STAR = b"*"  # array header prefix: *<count>\r\n
SYM_DOLLAR = b"$"  # bulk-string header prefix: $<length>\r\n
SYM_CRLF = b"\r\n"  # RESP line terminator
SYM_LF = b"\n"
SYM_EMPTY = b""  # joined with bytes.join() to concatenate fragments
98
99
class _Sentinel(enum.Enum):
    """Enum-based singleton used to tell "argument not supplied" apart from None."""

    sentinel = object()


# Default value for keyword arguments (e.g. retry_on_error) where None is meaningful.
SENTINEL = _Sentinel.sentinel
105
106
# Prefer the hiredis-backed parser (C extension) when installed; fall back to
# the pure-Python RESP2 parser otherwise.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
112
113
class ConnectCallbackProtocol(Protocol):
    """Structural type for a synchronous post-connect callback."""

    def __call__(self, connection: "AbstractConnection"): ...
116
117
class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for an asynchronous post-connect callback."""

    async def __call__(self, connection: "AbstractConnection"): ...
120
121
122ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
123
124
class AbstractConnection:
    """Manages communication to and from a Redis server"""

    # __slots__ keeps per-connection memory low; "__dict__" is still listed
    # so subclasses and mixins may attach attributes not declared here.
    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )
156
157 @deprecated_args(
158 args_to_warn=["lib_name", "lib_version"],
159 reason="Use 'driver_info' parameter instead. "
160 "lib_name and lib_version will be removed in a future version.",
161 )
162 def __init__(
163 self,
164 *,
165 db: Union[str, int] = 0,
166 password: Optional[str] = None,
167 socket_timeout: Optional[float] = None,
168 socket_connect_timeout: Optional[float] = None,
169 retry_on_timeout: bool = False,
170 retry_on_error: Union[list, _Sentinel] = SENTINEL,
171 encoding: str = "utf-8",
172 encoding_errors: str = "strict",
173 decode_responses: bool = False,
174 parser_class: Type[BaseParser] = DefaultParser,
175 socket_read_size: int = 65536,
176 health_check_interval: float = 0,
177 client_name: Optional[str] = None,
178 lib_name: Optional[str] = None,
179 lib_version: Optional[str] = None,
180 driver_info: Optional[DriverInfo] = None,
181 username: Optional[str] = None,
182 retry: Optional[Retry] = None,
183 redis_connect_func: Optional[ConnectCallbackT] = None,
184 encoder_class: Type[Encoder] = Encoder,
185 credential_provider: Optional[CredentialProvider] = None,
186 protocol: Optional[int] = 2,
187 event_dispatcher: Optional[EventDispatcher] = None,
188 ):
189 """
190 Initialize a new async Connection.
191
192 Parameters
193 ----------
194 driver_info : DriverInfo, optional
195 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
196 are ignored. If not provided, a DriverInfo will be created from lib_name
197 and lib_version (or defaults if those are also None).
198 lib_name : str, optional
199 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
200 lib_version : str, optional
201 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
202 """
203 if (username or password) and credential_provider is not None:
204 raise DataError(
205 "'username' and 'password' cannot be passed along with 'credential_"
206 "provider'. Please provide only one of the following arguments: \n"
207 "1. 'password' and (optional) 'username'\n"
208 "2. 'credential_provider'"
209 )
210 if event_dispatcher is None:
211 self._event_dispatcher = EventDispatcher()
212 else:
213 self._event_dispatcher = event_dispatcher
214 self.db = db
215 self.client_name = client_name
216
217 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
218 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
219
220 self.credential_provider = credential_provider
221 self.password = password
222 self.username = username
223 self.socket_timeout = socket_timeout
224 if socket_connect_timeout is None:
225 socket_connect_timeout = socket_timeout
226 self.socket_connect_timeout = socket_connect_timeout
227 self.retry_on_timeout = retry_on_timeout
228 if retry_on_error is SENTINEL:
229 retry_on_error = []
230 if retry_on_timeout:
231 retry_on_error.append(TimeoutError)
232 retry_on_error.append(socket.timeout)
233 retry_on_error.append(asyncio.TimeoutError)
234 self.retry_on_error = retry_on_error
235 if retry or retry_on_error:
236 if not retry:
237 self.retry = Retry(NoBackoff(), 1)
238 else:
239 # deep-copy the Retry object as it is mutable
240 self.retry = copy.deepcopy(retry)
241 # Update the retry's supported errors with the specified errors
242 self.retry.update_supported_errors(retry_on_error)
243 else:
244 self.retry = Retry(NoBackoff(), 0)
245 self.health_check_interval = health_check_interval
246 self.next_health_check: float = -1
247 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
248 self.redis_connect_func = redis_connect_func
249 self._reader: Optional[asyncio.StreamReader] = None
250 self._writer: Optional[asyncio.StreamWriter] = None
251 self._socket_read_size = socket_read_size
252 self.set_parser(parser_class)
253 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
254 self._buffer_cutoff = 6000
255 self._re_auth_token: Optional[TokenInterface] = None
256 self._should_reconnect = False
257
258 try:
259 p = int(protocol)
260 except TypeError:
261 p = DEFAULT_RESP_VERSION
262 except ValueError:
263 raise ConnectionError("protocol must be an integer")
264 else:
265 if p < 2 or p > 3:
266 raise ConnectionError("protocol must be either 2 or 3")
267 self.protocol = p
268
    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        # NOTE: ``_warnings`` is bound as a default argument so the module
        # stays reachable during interpreter shutdown, when globals may
        # already have been cleared.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

        # Only attempt the synchronous close while an event loop is running;
        # get_running_loop() raises RuntimeError otherwise (e.g. during GC
        # after the loop is gone), in which case we do nothing.
        try:
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No actions been taken if pool already closed.
            pass
284
285 def _close(self):
286 """
287 Internal method to silently close the connection without waiting
288 """
289 if self._writer:
290 self._writer.close()
291 self._writer = self._reader = None
292
293 def __repr__(self):
294 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
295 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
296
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for __repr__."""
        pass

    @property
    def is_connected(self):
        # Connected means both stream endpoints exist; _close()/disconnect()
        # reset them to None.
        return self._reader is not None and self._writer is not None
304
305 def register_connect_callback(self, callback):
306 """
307 Register a callback to be called when the connection is established either
308 initially or reconnected. This allows listeners to issue commands that
309 are ephemeral to the connection, for example pub/sub subscription or
310 key tracking. The callback must be a _method_ and will be kept as
311 a weak reference.
312 """
313 wm = weakref.WeakMethod(callback)
314 if wm not in self._connect_callbacks:
315 self._connect_callbacks.append(wm)
316
317 def deregister_connect_callback(self, callback):
318 """
319 De-register a previously registered callback. It will no-longer receive
320 notifications on connection events. Calling this is not required when the
321 listener goes away, since the callbacks are kept as weak methods.
322 """
323 try:
324 self._connect_callbacks.remove(weakref.WeakMethod(callback))
325 except ValueError:
326 pass
327
    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        # Replaces any existing parser; also used during the handshake to
        # switch from RESP2 to RESP3.
        self._parser = parser_class(socket_read_size=self._socket_read_size)
335
    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy.
        # disconnect() doubles as the failure callback so that partial state
        # is torn down between attempts.
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )
349
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Establish the socket connection, run the on-connect handshake and
        invoke any registered connect callbacks. No-op when already connected.

        :param check_health: forwarded to the default on-connect handshake
        :param retry_socket_connect: when True, retry the raw socket connect
            with this connection's retry policy; ``connect()`` passes False
            and instead retries the whole connect+handshake flow
        :raises TimeoutError: if the socket connect times out
        :raises ConnectionError: for OS errors or other connect failures
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            e = TimeoutError("Timeout connecting to server")
            # Record the failure for observability before propagating.
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except OSError as e:
            e = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
426
    def mark_for_reconnect(self):
        """Request that this connection be re-established before further use."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if a reconnect was requested via mark_for_reconnect()."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear a previously requested reconnect flag."""
        self._should_reconnect = False
435
    @abstractmethod
    async def _connect(self):
        """Create the underlying transport; implemented by subclasses."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a short endpoint description used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        # Combine the endpoint description with the exception details.
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version (2 or 3)."""
        return self.protocol
449
    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """
        Run the post-connect handshake: authenticate (AUTH, or HELLO with
        AUTH for RESP3), switch the parser protocol if needed, set the client
        name, report library name/version via CLIENT SETINFO, and select the
        configured database.

        :param check_health: forwarded to send_command for the non-AUTH
            commands. AUTH/HELLO are always sent with check_health=False,
            because the health-check PING would fail before authentication.
        :raises AuthenticationError: if the server rejects the credentials
        :raises ConnectionError: on RESP version mismatch, client-name or
            database-selection failure
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as the "default" user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO failures (e.g. older servers) are ignored
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
568
    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """
        Disconnects from the Redis server.

        :param nowait: when True, do not await the transport close handshake
            (used when forcefully tearing down after errors)
        :param error: the exception (if any) that triggered the disconnect;
            used to classify the close reason for observability metrics
        :param failure_count: number of failed retry attempts so far; feeds
            the error metric once retries are exhausted
        :param health_check_failed: True when triggered by a failed PING
            health check
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # Only record an error metric once the retry budget is exhausted.
            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )
634
    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        # Mark the disconnect as a failed health check for observability.
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        # Only ping when health checking is enabled and the interval elapsed;
        # next_health_check is advanced by read_response() on every reply.
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )
656
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Low-level write: queue all fragments on the transport, then drain
        # to respect the stream's flow control.
        self._writer.writelines(command)
        await self._writer.drain()
660
661 async def send_packed_command(
662 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
663 ) -> None:
664 if not self.is_connected:
665 await self.connect_check_health(check_health=False)
666 if check_health:
667 await self.check_health()
668
669 try:
670 if isinstance(command, str):
671 command = command.encode()
672 if isinstance(command, bytes):
673 command = [command]
674 if self.socket_timeout:
675 await asyncio.wait_for(
676 self._send_packed_command(command), self.socket_timeout
677 )
678 else:
679 self._writer.writelines(command)
680 await self._writer.drain()
681 except asyncio.TimeoutError:
682 await self.disconnect(nowait=True)
683 raise TimeoutError("Timeout writing to socket") from None
684 except OSError as e:
685 await self.disconnect(nowait=True)
686 if len(e.args) == 1:
687 err_no, errmsg = "UNKNOWN", e.args[0]
688 else:
689 err_no = e.args[0]
690 errmsg = e.args[1]
691 raise ConnectionError(
692 f"Error {err_no} while writing to socket. {errmsg}."
693 ) from e
694 except BaseException:
695 # BaseExceptions can be raised when a socket send operation is not
696 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
697 # to send un-sent data. However, the send_packed_command() API
698 # does not support it so there is no point in keeping the connection open.
699 await self.disconnect(nowait=True)
700 raise
701
    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        # check_health defaults to True; other keyword arguments are ignored.
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        # Delegates to the parser; any OS-level failure tears the connection
        # down before re-raising as ConnectionError.
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
716
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """
        Read the response from a previously sent command.

        :param disable_decoding: forwarded to the parser to skip decoding
        :param timeout: per-call read timeout. When it elapses, ``None`` is
            returned so the caller can retry; when omitted, ``socket_timeout``
            applies and its expiry disconnects and raises TimeoutError.
        :param disconnect_on_error: tear down the connection on read failure
        :param push_request: forwarded to the parser for RESP3 push messages
        :raises TimeoutError: on ``socket_timeout`` expiry
        :raises ConnectionError: on OS-level read errors
        :raises ResponseError: when the server returned an error reply
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # push_request is only passed through for RESP3 connections.
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        # A successful read proves liveness: push the next scheduled health
        # check forward by one interval.
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response
774
    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """
        Pack a series of arguments into the Redis protocol.

        Returns a list of bytes fragments suitable for writelines(). Large
        arguments and memoryviews are emitted as separate fragments so no
        single huge contiguous buffer has to be allocated.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # RESP array header: *<argument count>\r\n
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                # trailing CRLF for this arg becomes the start of the next buffer
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
820
    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """
        Pack multiple commands into the Redis protocol.

        Fragments from consecutive commands are coalesced into buffers of
        roughly ``_buffer_cutoff`` bytes; oversized fragments and memoryviews
        are appended to the output as-is.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush the accumulated pieces before handling this chunk
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
850
    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): peeks at the StreamReader's private ``_buffer``;
        # there is no public API for "bytes already buffered".
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        # Drain any already-buffered push (invalidation) messages without
        # blocking on the socket for new data.
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        # Stash a token; re_auth() will use and then clear it.
        self._re_auth_token = token

    async def re_auth(self):
        # Re-authenticate with the stored token, then clear it. No-op when
        # no token has been set.
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
871
872
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """
        Initialize a TCP connection.

        :param host: server hostname or IP address
        :param port: server port (coerced to int)
        :param socket_keepalive: enable SO_KEEPALIVE on the socket
        :param socket_keepalive_options: mapping of TCP keepalive socket
            option constants to their values
        :param socket_type: stored as-is; not used by this class's _connect()
        :param kwargs: forwarded to AbstractConnection
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # Arguments for asyncio.open_connection(); subclasses extend this
        # (e.g. SSLConnection adds the "ssl" key).
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle: Redis commands are small and latency-sensitive.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # Endpoint description used in error messages.
        return f"{self.host}:{self.port}"
928
929
930class SSLConnection(Connection):
931 """Manages SSL connections to and from the Redis server(s).
932 This class extends the Connection class, adding SSL functionality, and making
933 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
934 """
935
936 def __init__(
937 self,
938 ssl_keyfile: Optional[str] = None,
939 ssl_certfile: Optional[str] = None,
940 ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
941 ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
942 ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
943 ssl_ca_certs: Optional[str] = None,
944 ssl_ca_data: Optional[str] = None,
945 ssl_ca_path: Optional[str] = None,
946 ssl_check_hostname: bool = True,
947 ssl_min_version: Optional[TLSVersion] = None,
948 ssl_ciphers: Optional[str] = None,
949 ssl_password: Optional[str] = None,
950 **kwargs,
951 ):
952 if not SSL_AVAILABLE:
953 raise RedisError("Python wasn't built with SSL support")
954
955 self.ssl_context: RedisSSLContext = RedisSSLContext(
956 keyfile=ssl_keyfile,
957 certfile=ssl_certfile,
958 cert_reqs=ssl_cert_reqs,
959 include_verify_flags=ssl_include_verify_flags,
960 exclude_verify_flags=ssl_exclude_verify_flags,
961 ca_certs=ssl_ca_certs,
962 ca_data=ssl_ca_data,
963 ca_path=ssl_ca_path,
964 check_hostname=ssl_check_hostname,
965 min_version=ssl_min_version,
966 ciphers=ssl_ciphers,
967 password=ssl_password,
968 )
969 super().__init__(**kwargs)
970
971 def _connection_arguments(self) -> Mapping:
972 kwargs = super()._connection_arguments()
973 kwargs["ssl"] = self.ssl_context.get()
974 return kwargs
975
976 @property
977 def keyfile(self):
978 return self.ssl_context.keyfile
979
980 @property
981 def certfile(self):
982 return self.ssl_context.certfile
983
984 @property
985 def cert_reqs(self):
986 return self.ssl_context.cert_reqs
987
988 @property
989 def include_verify_flags(self):
990 return self.ssl_context.include_verify_flags
991
992 @property
993 def exclude_verify_flags(self):
994 return self.ssl_context.exclude_verify_flags
995
996 @property
997 def ca_certs(self):
998 return self.ssl_context.ca_certs
999
1000 @property
1001 def ca_data(self):
1002 return self.ssl_context.ca_data
1003
    @property
    def check_hostname(self):
        """Whether the server hostname is verified (always False when cert verification is off)."""
        return self.ssl_context.check_hostname
1007
    @property
    def min_version(self):
        """Minimum TLS version to accept, or None for the ``ssl`` default."""
        return self.ssl_context.min_version
1011
1012
class RedisSSLContext:
    """Holds redis-py's TLS settings and lazily builds the matching
    :class:`ssl.SSLContext`.

    The context is built once, on the first call to :meth:`get`, and cached
    for the lifetime of this object.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Normalize and store the TLS settings.

        ``cert_reqs`` may be one of the strings "none"/"optional"/"required"
        or an :class:`ssl.VerifyMode`; ``None`` means no verification.
        Raises :class:`RedisError` without SSL support or for an unknown
        ``cert_reqs`` string.
        """
        # NOTE: ``ssl.VerifyMode`` is quoted (like ``ssl.VerifyFlags``)
        # because ``ssl`` is None when SSL support is unavailable and an
        # unquoted annotation would be evaluated at import time.
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            cert_reqs_map = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in cert_reqs_map:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = cert_reqs_map[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is meaningless (and rejected by ssl) without
        # certificate verification, so force it off for CERT_NONE.
        if self.cert_reqs == ssl.CERT_NONE:
            self.check_hostname = False
        else:
            self.check_hostname = check_hostname
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        # Built lazily by get().
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached SSLContext, building it on first use."""
        if self.context is None:
            self.context = self._build_context()
        return self.context

    def _build_context(self) -> SSLContext:
        """Translate the stored settings into a configured ssl.SSLContext."""
        context = ssl.create_default_context()
        # check_hostname must be assigned before verify_mode: setting
        # verify_mode to CERT_NONE while check_hostname is True raises
        # ValueError in the ssl module.
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        for flag in self.include_verify_flags or ():
            context.verify_flags |= flag
        for flag in self.exclude_verify_flags or ():
            context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.password,
            )
        if self.ca_certs or self.ca_data or self.ca_path:
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.min_version is not None:
            context.minimum_version = self.min_version
        if self.ciphers is not None:
            context.set_ciphers(self.ciphers)
        return context
1104
1105
class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Remember the socket *path*, then run the shared connection setup."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Key/value pairs describing this connection for ``__repr__``."""
        parts: List[Tuple[str, Union[str, int]]] = [
            ("path", self.path),
            ("db", self.db),
        ]
        if self.client_name:
            parts.append(("client_name", self.client_name))
        return parts

    async def _connect(self):
        """Open the unix-domain socket, then run the on-connect handshake."""
        async with async_timeout(self.socket_connect_timeout):
            self._reader, self._writer = await asyncio.open_unix_connection(
                path=self.path
            )
        await self.on_connect()

    def _host_error(self) -> str:
        # In error messages the socket path plays the role of the host.
        return self.path
1128
1129
# Upper-cased string values that parse as boolean False in URL options.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a querystring value to a bool; None or "" mean "unset"."""
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
1139
1140
1141def parse_ssl_verify_flags(value):
1142 # flags are passed in as a string representation of a list,
1143 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1144 verify_flags_str = value.replace("[", "").replace("]", "")
1145
1146 verify_flags = []
1147 for flag in verify_flags_str.split(","):
1148 flag = flag.strip()
1149 if not hasattr(VerifyFlags, flag):
1150 raise ValueError(f"Invalid ssl verify flag: {flag}")
1151 verify_flags.append(getattr(VerifyFlags, flag))
1152 return verify_flags
1153
1154
# Maps querystring option names to callables that coerce the raw string
# value to the proper Python type. Options not listed here are passed
# through unchanged as strings (see parse_url). Read-only on purpose.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1170
1171
class ConnectKwargs(TypedDict, total=False):
    """Connection keyword arguments produced by parse_url(); every key is
    optional (``total=False``) — only what the URL specifies is present."""

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
1180
1181
def parse_url(url: str) -> ConnectKwargs:
    """Break a redis://, rediss:// or unix:// URL into connection kwargs.

    Querystring options are percent-decoded and, where a parser is
    registered in URL_QUERY_ARGUMENT_PARSERS, coerced to their Python
    type; unknown options pass through as strings. Raises ValueError for
    an unsupported scheme or an uncoercible option value.
    """
    components: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, raw_values in parse_qs(components.query).items():
        if not raw_values:
            continue
        value = unquote(raw_values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = value
        else:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if components.username:
        kwargs["username"] = unquote(components.username)
    if components.password:
        kwargs["password"] = unquote(components.password)

    # Only redis://, rediss:// and unix:// schemes are supported.
    if components.scheme == "unix":
        if components.path:
            kwargs["path"] = unquote(components.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    if components.scheme in ("redis", "rediss"):
        if components.hostname:
            kwargs["host"] = unquote(components.hostname)
        if components.port:
            kwargs["port"] = int(components.port)

        # The URL path doubles as the db number unless ?db= was given.
        if components.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(components.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if components.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection
        return kwargs

    valid_schemes = "redis://, rediss://, unix://"
    raise ValueError(
        f"Redis URL must specify one of the following schemes ({valid_schemes})"
    )
1233
1234
1235_CP = TypeVar("_CP", bound="ConnectionPool")
1236
1237
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # Querystring/URL options take precedence over keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """Set up an empty pool; connections are created lazily on demand."""
        # 2**31 effectively means "unbounded" while keeping comparisons valid.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Guards pool bookkeeping in get_connection()/update_active_connections.
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self):
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def _record_pool_teardown_counts(self) -> None:
        """Emit negative IDLE/USED connection-count metrics for every tracked
        connection so the observability gauges return to zero when the pool
        is cleared. No-op if called before __init__ has set up the pool.
        """
        # Attributes won't exist if invoked during/before __init__.
        if not hasattr(self, "_available_connections") or not hasattr(
            self, "_in_use_connections"
        ):
            return
        idle_count = len(self._available_connections)
        in_use_count = len(self._in_use_connections)
        if idle_count == 0 and in_use_count == 0:
            return
        pool_name = get_pool_name(self)
        # Use the sync recorder: both call sites (reset/__del__) are sync.
        from redis.observability.recorder import (
            record_connection_count as sync_record_connection_count,
        )

        if idle_count > 0:
            sync_record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-idle_count,
            )
        if in_use_count > 0:
            sync_record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=-in_use_count,
            )

    def reset(self):
        """Forget all tracked connections and start with an empty pool."""
        self._record_pool_teardown_counts()
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set() here; the WeakSet means
        # connections still held as "in use" by callers after reset() are
        # not kept alive by the pool — confirm the asymmetry is intentional.
        self._in_use_connections = weakref.WeakSet()

    def __del__(self) -> None:
        """Record teardown metrics when the pool is garbage collected."""
        try:
            self._record_pool_teardown_counts()
        except Exception:
            # Never raise from __del__.
            pass

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() so the annotated return type holds even when the first
        # operand (the non-empty list itself) decides the result.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # Track connection count before to detect if a new connection is created
        async with self._lock:
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

            # Record state transition for observability
            # This ensures counters stay balanced if ensure_connection() fails and release() is called
            pool_name = get_pool_name(self)
            if is_created:
                # New connection created and acquired: just USED +1
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )
            else:
                # Existing connection acquired from pool: IDLE -> USED
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=-1,
                )
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            # Return the connection (balancing the counters) before propagating.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        # Note: We don't record IDLE here because async uses a sync make_connection
        # but async record_connection_count. The recording is handled in get_connection.
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)

        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

        # Record state transition: USED -> IDLE
        pool_name = get_pool_name(self)
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=-1,
        )
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        # Disconnect everything concurrently; collect failures instead of
        # aborting on the first one.
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )

        # Re-raise the first failure, if any, after all disconnects ran.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Replace the retry policy on every pooled connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate idle connections with *token* immediately and
        schedule in-use connections to re-auth when they are released."""
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]
1624
1625
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        """Create the pool; ``timeout=None`` blocks forever for a connection.

        ``queue_class`` is deprecated and unused; it is kept only so older
        call sites passing it keep working.
        """
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # The condition guards pool state and lets waiters sleep until release().
        self._condition = asyncio.Condition()
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        # Start timing for wait-time observability.
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before/after to detect whether a
                    # brand-new connection was created.
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # Record the acquisition state transition, mirroring
        # ConnectionPool.get_connection(). Without this, release() would keep
        # decrementing USED / incrementing IDLE while nothing ever recorded
        # the acquisition, so the connection-count gauges would drift.
        pool_name = get_pool_name(self)
        if is_created:
            # New connection created and acquired: just USED +1
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )
        else:
            # Existing connection acquired from pool: IDLE -> USED
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-1,
            )
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            await record_connection_wait_time(
                pool_name=get_pool_name(self),
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            # Return the connection (waking one waiter) before propagating.
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one task blocked in get_connection().
            self._condition.notify()