1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import time
8import warnings
9import weakref
10from abc import ABC, abstractmethod
11from itertools import chain
12from types import MappingProxyType
13from typing import (
14 Any,
15 Callable,
16 Iterable,
17 List,
18 Mapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27)
28from urllib.parse import ParseResult, parse_qs, unquote, urlparse
29
30from ..observability.attributes import (
31 DB_CLIENT_CONNECTION_POOL_NAME,
32 DB_CLIENT_CONNECTION_STATE,
33 AttributeBuilder,
34 ConnectionState,
35 get_pool_name,
36)
37from ..utils import SSL_AVAILABLE
38
39if SSL_AVAILABLE:
40 import ssl
41 from ssl import SSLContext, TLSVersion, VerifyFlags
42else:
43 ssl = None
44 TLSVersion = None
45 SSLContext = None
46 VerifyFlags = None
47
48from ..auth.token import TokenInterface
49from ..driver_info import DriverInfo, resolve_driver_info
50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
51from ..utils import deprecated_args, format_error_message
52
53# the functionality is available in 3.11.x but has a major issue before
54# 3.11.3. See https://github.com/redis/redis-py/issues/2633
55if sys.version_info >= (3, 11, 3):
56 from asyncio import timeout as async_timeout
57else:
58 from async_timeout import timeout as async_timeout
59
60from redis.asyncio.observability.recorder import (
61 record_connection_closed,
62 record_connection_count,
63 record_connection_create_time,
64 record_connection_wait_time,
65 record_error_count,
66)
67from redis.asyncio.retry import Retry
68from redis.backoff import NoBackoff
69from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
70from redis.exceptions import (
71 AuthenticationError,
72 AuthenticationWrongNumberOfArgsError,
73 ConnectionError,
74 DataError,
75 MaxConnectionsError,
76 RedisError,
77 ResponseError,
78 TimeoutError,
79)
80from redis.observability.metrics import CloseReason
81from redis.typing import EncodableT
82from redis.utils import DEFAULT_RESP_VERSION, HIREDIS_AVAILABLE, str_if_bytes
83
84from .._parsers import (
85 BaseParser,
86 Encoder,
87 _AsyncHiredisParser,
88 _AsyncRESP2Parser,
89 _AsyncRESP3Parser,
90)
91
92SYM_STAR = b"*"
93SYM_DOLLAR = b"$"
94SYM_CRLF = b"\r\n"
95SYM_LF = b"\n"
96SYM_EMPTY = b""
97
98
class _Sentinel(enum.Enum):
    """Sentinel type used to distinguish "argument not supplied" from ``None``."""

    sentinel = object()


# Module-level singleton used as the default for optional arguments.
SENTINEL = _Sentinel.sentinel
104
105
# Pick the fastest available parser: the hiredis C extension handles both
# RESP2 and RESP3 natively; otherwise fall back to the pure-Python RESP3 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP3Parser
111
112
class ConnectCallbackProtocol(Protocol):
    """Signature of a synchronous callback invoked after (re)connecting."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Signature of an asynchronous callback invoked after (re)connecting."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Either flavour of callback may be registered on a connection.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
122
123
124class AbstractConnection:
125 """Manages communication to and from a Redis server"""
126
127 __slots__ = (
128 "db",
129 "username",
130 "client_name",
131 "lib_name",
132 "lib_version",
133 "credential_provider",
134 "password",
135 "socket_timeout",
136 "socket_connect_timeout",
137 "redis_connect_func",
138 "retry_on_timeout",
139 "retry_on_error",
140 "health_check_interval",
141 "next_health_check",
142 "last_active_at",
143 "encoder",
144 "ssl_context",
145 "protocol",
146 "_reader",
147 "_writer",
148 "_parser",
149 "_connect_callbacks",
150 "_buffer_cutoff",
151 "_lock",
152 "_socket_read_size",
153 "__dict__",
154 )
155
156 @deprecated_args(
157 args_to_warn=["lib_name", "lib_version"],
158 reason="Use 'driver_info' parameter instead. "
159 "lib_name and lib_version will be removed in a future version.",
160 )
161 def __init__(
162 self,
163 *,
164 db: Union[str, int] = 0,
165 password: Optional[str] = None,
166 socket_timeout: Optional[float] = None,
167 socket_connect_timeout: Optional[float] = None,
168 retry_on_timeout: bool = False,
169 retry_on_error: Union[list, _Sentinel] = SENTINEL,
170 encoding: str = "utf-8",
171 encoding_errors: str = "strict",
172 decode_responses: bool = False,
173 parser_class: Type[BaseParser] = DefaultParser,
174 socket_read_size: int = 65536,
175 health_check_interval: float = 0,
176 client_name: Optional[str] = None,
177 lib_name: Optional[str] = None,
178 lib_version: Optional[str] = None,
179 driver_info: Optional[DriverInfo] = None,
180 username: Optional[str] = None,
181 retry: Optional[Retry] = None,
182 redis_connect_func: Optional[ConnectCallbackT] = None,
183 encoder_class: Type[Encoder] = Encoder,
184 credential_provider: Optional[CredentialProvider] = None,
185 protocol: Optional[int] = 3,
186 event_dispatcher: Optional[EventDispatcher] = None,
187 ):
188 """
189 Initialize a new async Connection.
190
191 Parameters
192 ----------
193 driver_info : DriverInfo, optional
194 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
195 are ignored. If not provided, a DriverInfo will be created from lib_name
196 and lib_version (or defaults if those are also None).
197 lib_name : str, optional
198 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
199 lib_version : str, optional
200 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
201 """
202 if (username or password) and credential_provider is not None:
203 raise DataError(
204 "'username' and 'password' cannot be passed along with 'credential_"
205 "provider'. Please provide only one of the following arguments: \n"
206 "1. 'password' and (optional) 'username'\n"
207 "2. 'credential_provider'"
208 )
209 if event_dispatcher is None:
210 self._event_dispatcher = EventDispatcher()
211 else:
212 self._event_dispatcher = event_dispatcher
213 self.db = db
214 self.client_name = client_name
215
216 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
217 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
218
219 self.credential_provider = credential_provider
220 self.password = password
221 self.username = username
222 self.socket_timeout = socket_timeout
223 if socket_connect_timeout is None:
224 socket_connect_timeout = socket_timeout
225 self.socket_connect_timeout = socket_connect_timeout
226 self.retry_on_timeout = retry_on_timeout
227 if retry_on_error is SENTINEL:
228 retry_on_error = []
229 if retry_on_timeout:
230 retry_on_error.append(TimeoutError)
231 retry_on_error.append(socket.timeout)
232 retry_on_error.append(asyncio.TimeoutError)
233 self.retry_on_error = retry_on_error
234 if retry or retry_on_error:
235 if not retry:
236 self.retry = Retry(NoBackoff(), 1)
237 else:
238 # deep-copy the Retry object as it is mutable
239 self.retry = copy.deepcopy(retry)
240 # Update the retry's supported errors with the specified errors
241 self.retry.update_supported_errors(retry_on_error)
242 else:
243 self.retry = Retry(NoBackoff(), 0)
244 self.health_check_interval = health_check_interval
245 self.next_health_check: float = -1
246 self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
247 self.redis_connect_func = redis_connect_func
248 self._reader: Optional[asyncio.StreamReader] = None
249 self._writer: Optional[asyncio.StreamWriter] = None
250 self._socket_read_size = socket_read_size
251 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
252 self._buffer_cutoff = 6000
253 self._re_auth_token: Optional[TokenInterface] = None
254 self._should_reconnect = False
255
256 try:
257 p = int(protocol)
258 except TypeError:
259 p = DEFAULT_RESP_VERSION
260 except ValueError:
261 raise ConnectionError("protocol must be an integer")
262 else:
263 if p < 2 or p > 3:
264 raise ConnectionError("protocol must be either 2 or 3")
265 self.protocol = p
266 # Reconcile parser ↔ protocol mismatches.
267 # Hiredis handles both RESP2 and RESP3 natively, so only
268 # pure-Python parsers need to be swapped.
269 if self.protocol == 3 and parser_class == _AsyncRESP2Parser:
270 parser_class = _AsyncRESP3Parser
271 elif self.protocol == 2 and parser_class == _AsyncRESP3Parser:
272 parser_class = _AsyncRESP2Parser
273 self.set_parser(parser_class)
274
    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        # NOTE: `warnings` is bound as a default argument so the module stays
        # reachable even during interpreter shutdown, when globals may be gone.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

        try:
            # _close() is only safe to call when a loop is still running;
            # get_running_loop() raises RuntimeError otherwise.
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No actions been taken if pool already closed.
            pass
290
291 def _close(self):
292 """
293 Internal method to silently close the connection without waiting
294 """
295 if self._writer:
296 self._writer.close()
297 self._writer = self._reader = None
298
299 def __repr__(self):
300 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
301 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
302
    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for __repr__."""
        pass
306
307 @property
308 def is_connected(self):
309 return self._reader is not None and self._writer is not None
310
311 def register_connect_callback(self, callback):
312 """
313 Register a callback to be called when the connection is established either
314 initially or reconnected. This allows listeners to issue commands that
315 are ephemeral to the connection, for example pub/sub subscription or
316 key tracking. The callback must be a _method_ and will be kept as
317 a weak reference.
318 """
319 wm = weakref.WeakMethod(callback)
320 if wm not in self._connect_callbacks:
321 self._connect_callbacks.append(wm)
322
323 def deregister_connect_callback(self, callback):
324 """
325 De-register a previously registered callback. It will no-longer receive
326 notifications on connection events. Calling this is not required when the
327 listener goes away, since the callbacks are kept as weak methods.
328 """
329 try:
330 self._connect_callbacks.remove(weakref.WeakMethod(callback))
331 except ValueError:
332 pass
333
    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        # Any state held by a previously installed parser is discarded.
        self._parser = parser_class(socket_read_size=self._socket_read_size)
341
    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            # On each failure, tear the connection down (recording metrics)
            # before the next attempt.
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )
355
    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (optionally with retries), run the handshake, and
        fire any registered connect callbacks. No-op when already connected.

        :param check_health: forwarded to the default on-connect handshake.
        :param retry_socket_connect: when True, the raw socket connect itself
            is retried per the connection's retry policy; the caller
            (``connect``) sets this False because it retries the whole flow.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            # Remember how many attempts actually failed so error metrics
            # below can report it, then disconnect before the next attempt.
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            # Normalize both timeout flavours to redis TimeoutError and record
            # the failure in the observability metrics before raising.
            e = TimeoutError("Timeout connecting to server")
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except OSError as e:
            # Wrap OS-level errors with host/port context.
            e = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task
432
    def mark_for_reconnect(self):
        """Flag this connection to be re-established before its next use."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if the connection has been flagged for reconnection."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag."""
        self._should_reconnect = False
441
    @abstractmethod
    async def _connect(self):
        """Establish the underlying transport (implemented by subclasses)."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint description for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        # Combine the endpoint description with the OS-level exception details.
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version (2 or 3)."""
        return self.protocol
455
    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        # Delegates to the full handshake with health checking enabled.
        await self.on_connect_check_health(check_health=True)
459
    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake.

        In order: authenticate (via HELLO AUTH on RESP3, or AUTH on RESP2),
        switch RESP protocol, set the client name, then pipeline CLIENT
        SETINFO and SELECT for lower startup latency.
        """
        self._parser.on_connect(self)
        # Keep a reference to the original parser so exception classes can be
        # carried over if the parser is swapped for RESP3 below.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # A bare password authenticates as the "default" user.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The HELLO reply may use bytes or str keys depending on
            # decode_responses; accept either.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort: older servers reject it.
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")
574
    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server.

        :param nowait: when True, do not await the transport's close handshake.
        :param error: the exception that caused the disconnect, if any; used
            for close-reason/error metrics.
        :param failure_count: number of failed retry attempts so far, used to
            decide whether retries were exhausted for metrics purposes.
        :param health_check_failed: marks the close reason as a failed
            PING health check rather than a generic error.
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # Only record an error metric once retries have been exhausted.
            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            # Clean shutdown initiated by the application.
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )
640
641 async def _send_ping(self):
642 """Send PING, expect PONG in return"""
643 await self.send_command("PING", check_health=False)
644 if str_if_bytes(await self.read_response()) != "PONG":
645 raise ConnectionError("Bad response from PING health check")
646
    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        # Tear down the connection, tagging the close as a health-check failure
        # so metrics distinguish it from other error-driven disconnects.
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )
652
    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        # Only ping when health checks are enabled AND the configured interval
        # has elapsed since the last observed response.
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )
662
    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Write all chunks to the transport and wait for the write buffer
        # to drain (applies flow control back-pressure).
        self._writer.writelines(command)
        await self._writer.drain()
666
667 async def send_packed_command(
668 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
669 ) -> None:
670 if not self.is_connected:
671 await self.connect_check_health(check_health=False)
672 if check_health:
673 await self.check_health()
674
675 try:
676 if isinstance(command, str):
677 command = command.encode()
678 if isinstance(command, bytes):
679 command = [command]
680 if self.socket_timeout:
681 await asyncio.wait_for(
682 self._send_packed_command(command), self.socket_timeout
683 )
684 else:
685 self._writer.writelines(command)
686 await self._writer.drain()
687 except asyncio.TimeoutError:
688 await self.disconnect(nowait=True)
689 raise TimeoutError("Timeout writing to socket") from None
690 except OSError as e:
691 await self.disconnect(nowait=True)
692 if len(e.args) == 1:
693 err_no, errmsg = "UNKNOWN", e.args[0]
694 else:
695 err_no = e.args[0]
696 errmsg = e.args[1]
697 raise ConnectionError(
698 f"Error {err_no} while writing to socket. {errmsg}."
699 ) from e
700 except BaseException:
701 # BaseExceptions can be raised when a socket send operation is not
702 # finished, e.g. due to a timeout. Ideally, a caller could then re-try
703 # to send un-sent data. However, the send_packed_command() API
704 # does not support it so there is no point in keeping the connection open.
705 await self.disconnect(nowait=True)
706 raise
707
708 async def send_command(self, *args: Any, **kwargs: Any) -> None:
709 """Pack and send a command to the Redis server"""
710 await self.send_packed_command(
711 self.pack_command(*args), check_health=kwargs.get("check_health", True)
712 )
713
    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read.

        Destructive: any available data may be consumed into the parser's
        buffer as a side effect.
        """
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
722
    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: bypass the encoder when parsing the reply.
        :param timeout: per-call read timeout; when it elapses, ``None`` is
            returned (the operation may be retried). When omitted, the
            connection's ``socket_timeout`` applies and its expiry raises.
        :param disconnect_on_error: tear down the connection on read errors.
        :param push_request: RESP3 only -- pass push-message context through
            to the parser.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # Four branches: (timeout x RESP3 push support). Only RESP3
            # parsers accept the push_request keyword.
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        # A successful read counts as proof of liveness: push the next
        # scheduled health check out by one interval.
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response
780
    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks; large arguments and memoryviews are
        kept as separate chunks to avoid large intermediate string copies.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # RESP array header: *<number of arguments>\r\n
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                # Flush the accumulated buffer, emit the argument as its own
                # chunk, and start the next buffer with the trailing CRLF.
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                # Small argument: append $<len>\r\n<arg>\r\n inline.
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
826
    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced into larger buffers; oversized chunks and
        memoryviews are passed through as standalone output entries.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # Flush whatever has been coalesced so far.
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    # Too big (or zero-copy view): emit as its own entry.
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        # Flush the final partial buffer.
        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
856
    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): reaches into StreamReader's private ``_buffer``;
        # asyncio exposes no public non-blocking "data available" check.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        # Drain any push messages (e.g. client-side cache invalidations)
        # already sitting in the read buffer.
        while not self._socket_is_empty():
            await self.read_response(push_request=True)
864
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used by the next :meth:`re_auth` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Re-authenticate with the stored token, if any, then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
877
878
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        # TCP-specific settings are consumed here; everything else is
        # forwarded to AbstractConnection.__init__.
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (name, value) pairs for __repr__: host, port, db, name."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # Keyword arguments forwarded to asyncio.open_connection().
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        # The connect timeout only covers establishing the stream, not the
        # socket-option setup below.
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle's algorithm for lower command latency.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # Endpoint description used in error messages.
        return f"{self.host}:{self.port}"
934
935
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Collect the ssl_* settings into a RedisSSLContext and delegate the
        rest to Connection.__init__."""
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        # The actual ssl.SSLContext is built lazily by RedisSSLContext.get().
        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # Add the SSL context to the plain TCP connection arguments so that
        # asyncio.open_connection() negotiates TLS.
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the underlying RedisSSLContext configuration.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
1017
1018
class RedisSSLContext:
    """Holds TLS configuration and lazily builds the :class:`ssl.SSLContext`
    handed to asyncio when an :class:`SSLConnection` connects.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Store TLS settings; the actual context is built on first ``get()``.

        :raises RedisError: if Python was compiled without SSL support, or if
            ``cert_reqs`` is an unrecognized string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        # Map user-facing strings to ssl.VerifyMode values.
        name_to_mode = {  # noqa: N806
            "none": ssl.CERT_NONE,
            "optional": ssl.CERT_OPTIONAL,
            "required": ssl.CERT_REQUIRED,
        }
        if cert_reqs is None:
            self.cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            if cert_reqs not in name_to_mode:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            self.cert_reqs = name_to_mode[cert_reqs]
        else:
            self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # The ssl module forbids hostname checking when certificates are not
        # verified, so force it off in that case.
        if self.cert_reqs == ssl.CERT_NONE:
            self.check_hostname = False
        else:
            self.check_hostname = check_hostname
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Build (once) and return the configured SSLContext."""
        if self.context is None:
            ctx = ssl.create_default_context()
            ctx.check_hostname = self.check_hostname
            ctx.verify_mode = self.cert_reqs
            for flag in self.include_verify_flags or ():
                ctx.verify_flags |= flag
            for flag in self.exclude_verify_flags or ():
                ctx.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                ctx.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                ctx.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                ctx.minimum_version = self.min_version
            if self.ciphers is not None:
                ctx.set_ciphers(self.ciphers)
            self.context = ctx
        return self.context
1110
1111
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, *, path: str = "", **kwargs):
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return (label, value) pairs used to build the connection repr."""
        pieces: List[Tuple[str, Union[str, int]]] = [
            ("path", self.path),
            ("db", self.db),
        ]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the unix-socket stream and run the connect handshake."""
        async with async_timeout(self.socket_connect_timeout):
            self._reader, self._writer = await asyncio.open_unix_connection(
                path=self.path
            )
        # The handshake is intentionally outside the connect timeout.
        await self.on_connect()

    def _host_error(self) -> str:
        """Unix sockets are identified by path rather than host:port."""
        return self.path
1134
1135
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a querystring value to a bool.

    ``None`` and the empty string mean "unspecified" and map to ``None``.
    Strings are compared case-insensitively against ``FALSE_STRINGS``;
    anything else falls back to Python truthiness.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        return value.upper() not in FALSE_STRINGS
    return bool(value)
1145
1146
1147def parse_ssl_verify_flags(value):
1148 # flags are passed in as a string representation of a list,
1149 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1150 verify_flags_str = value.replace("[", "").replace("]", "")
1151
1152 verify_flags = []
1153 for flag in verify_flags_str.split(","):
1154 flag = flag.strip()
1155 if not hasattr(VerifyFlags, flag):
1156 raise ValueError(f"Invalid ssl verify flag: {flag}")
1157 verify_flags.append(getattr(VerifyFlags, flag))
1158 return verify_flags
1159
1160
# Parsers applied to recognized querystring arguments in ``parse_url``.
# Keys are the argument names; values convert the raw (already unquoted)
# string to the appropriate Python type.  Arguments not listed here are
# passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
1176
1177
class ConnectKwargs(TypedDict, total=False):
    """Connection keyword arguments produced by :func:`parse_url`.

    ``total=False``: only the keys actually present in the URL are populated.
    """

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
1186
1187
def parse_url(url: str) -> ConnectKwargs:
    """Parse a Redis URL into a dict of connection keyword arguments.

    Supports the ``redis://``, ``rediss://`` and ``unix://`` schemes.
    Username, password, hostname, path and querystring values are passed
    through ``urllib.parse.unquote``. Recognized querystring arguments are
    converted via ``URL_QUERY_ARGUMENT_PARSERS``; unrecognized ones are kept
    as strings.

    :raises ValueError: if a querystring value cannot be parsed, or the URL
        scheme is not one of the supported ones.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs only yields non-empty value lists, but guard anyway;
        # the first value wins when an argument is repeated.
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError) as err:
                    # Chain the original parse error for easier debugging.
                    raise ValueError(
                        f"Invalid value for '{name}' in connection URL."
                    ) from err
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # A non-numeric path (e.g. "/") is silently ignored.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
1239
1240
1241_CP = TypeVar("_CP", bound="ConnectionPool")
1242
1243
class ConnectionPoolInterface(ABC):
    """Abstract interface implemented by asyncio connection pools."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version, or ``None``."""
        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the pool's internal state, dropping held connections."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> "AbstractConnection":
        """Acquire a connected, ready-to-use connection from the pool."""
        pass

    @abstractmethod
    def get_encoder(self) -> "Encoder":
        """Return an encoder configured from the pool's settings."""
        pass

    @abstractmethod
    async def release(self, connection: "AbstractConnection") -> None:
        """Return a connection previously acquired via ``get_connection()``."""
        pass

    @abstractmethod
    async def disconnect(self, inuse_connections: bool = True) -> None:
        """Disconnect pooled connections (optionally also in-use ones)."""
        pass

    @abstractmethod
    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections."""
        pass

    @abstractmethod
    def set_retry(self, retry: "Retry") -> None:
        """Set the retry policy on all connections managed by the pool."""
        pass

    @abstractmethod
    async def re_auth_callback(self, token: TokenInterface) -> None:
        """Re-authenticate pooled connections with a fresh token."""
        pass

    @abstractmethod
    def get_connection_count(self) -> List[Tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        pass
1294
1295
class ConnectionPool(ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.MaxConnectionsError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # Querystring/URL options take precedence over keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: class used to create new connections
        :param max_connections: maximum total connections this pool may hold;
            ``None`` (or ``0``, both falsy) selects the effectively unlimited
            default of ``2**31``
        :param connection_kwargs: forwarded to ``connection_class``
        :raises ValueError: if ``max_connections`` is negative or not an int
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self):
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self):
        """Drop all connection references, recording metric decrements first."""
        # Record metrics for connections being removed before clearing
        # (only if attributes exist - they won't during __init__)
        if hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        ):
            idle_count = len(self._available_connections)
            in_use_count = len(self._in_use_connections)
            if idle_count > 0 or in_use_count > 0:
                pool_name = get_pool_name(self)
                # Note: Using sync version since reset() is sync
                from redis.observability.recorder import (
                    record_connection_count as sync_record_connection_count,
                )

                if idle_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.IDLE,
                        counter=-idle_count,
                    )
                if in_use_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.USED,
                        counter=-in_use_count,
                    )

        self._available_connections = []
        # NOTE(review): __init__ uses a plain set() while reset() uses a
        # WeakSet, so connections checked out before a reset won't keep their
        # pool entry alive -- presumably intentional (fork safety); confirm.
        self._in_use_connections = weakref.WeakSet()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        try:
            if not hasattr(self, "_available_connections") or not hasattr(
                self, "_in_use_connections"
            ):
                return
            idle_count = len(self._available_connections)
            in_use_count = len(self._in_use_connections)
            if idle_count > 0 or in_use_count > 0:
                pool_name = get_pool_name(self)
                # Note: Using sync version since __del__ is sync
                from redis.observability.recorder import (
                    record_connection_count as sync_record_connection_count,
                )

                if idle_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.IDLE,
                        counter=-idle_count,
                    )
                if in_use_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.USED,
                        counter=-in_use_count,
                    )
        except Exception:
            # __del__ must never raise; interpreter shutdown may have torn
            # down the modules this relies on.
            pass

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # Coerce to a real bool: previously this returned the (truthy) list
        # itself when idle connections were available, violating the
        # annotated return type.
        return (
            bool(self._available_connections)
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # Track connection count before to detect if a new connection is created
        async with self._lock:
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

            # Record state transition for observability
            # This ensures counters stay balanced if ensure_connection() fails and release() is called
            pool_name = get_pool_name(self)
            if is_created:
                # New connection created and acquired: just USED +1
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )
            else:
                # Existing connection acquired from pool: IDLE -> USED
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=-1,
                )
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            # Return the connection so the USED counter is balanced, then
            # propagate the original failure.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        # Note: We don't record IDLE here because async uses a sync make_connection
        # but async record_connection_count. The recording is handled in get_connection.
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)

        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

        # Record state transition: USED -> IDLE
        pool_name = get_pool_name(self)
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=-1,
        )
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        :raises BaseException: re-raises the first error encountered while
            disconnecting, after all disconnects have been attempted.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )

        # Surface the first failure only after every disconnect has run.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Set the retry policy on every connection the pool knows about."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate all pooled connections with the given token.

        Idle connections are AUTH-ed immediately; in-use connections are
        tagged so they re-authenticate when next convenient.
        """
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[Tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use), each paired with
        the metric attributes describing its connection state.
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]
1690
1691
1692class BlockingConnectionPool(ConnectionPool):
1693 """
1694 A blocking connection pool::
1695
1696 >>> from redis.asyncio import Redis, BlockingConnectionPool
1697 >>> client = Redis.from_pool(BlockingConnectionPool())
1698
1699 It performs the same function as the default
1700 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
1701 it maintains a pool of reusable connections that can be shared by
1702 multiple async redis clients.
1703
1704 The difference is that, in the event that a client tries to get a
1705 connection from the pool when all of connections are in use, rather than
1706 raising a :py:class:`~redis.ConnectionError` (as the default
1707 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
1708 blocks the current `Task` for a specified number of seconds until
1709 a connection becomes available.
1710
1711 Use ``max_connections`` to increase / decrease the pool size::
1712
1713 >>> pool = BlockingConnectionPool(max_connections=10)
1714
1715 Use ``timeout`` to tell it either how many seconds to wait for a connection
1716 to become available, or to block forever:
1717
1718 >>> # Block forever.
1719 >>> pool = BlockingConnectionPool(timeout=None)
1720
1721 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
1722 >>> # not available.
1723 >>> pool = BlockingConnectionPool(timeout=5)
1724 """
1725
    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated: accepted for backward compatibility only; not used
        **connection_kwargs,
    ):
        """
        :param max_connections: pool size; waiters block once it is reached
        :param timeout: seconds to wait for a connection before raising
            ``ConnectionError``; ``None`` blocks forever
        :param connection_class: class used to create new connections
        :param connection_kwargs: forwarded to ``connection_class``
        """
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Condition used to wake blocked waiters when a connection is released.
        self._condition = asyncio.Condition()
        self.timeout = timeout
1741
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available.

        :raises ConnectionError: if no connection becomes available within
            ``self.timeout`` seconds.
        """
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                # Wait (bounded by self.timeout) until the pool can hand out
                # a connection; the condition is notified by release().
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # NOTE(review): unlike ConnectionPool.get_connection(), this path does
        # not record IDLE/USED connection-count transitions, while release()
        # (via the superclass) records USED -1 / IDLE +1 -- verify the
        # counters stay balanced for blocking pools.
        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            await record_connection_wait_time(
                pool_name=get_pool_name(self),
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            # Return the connection (and wake a waiter) before propagating.
            await self.release(connection)
            raise
1788
    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        Wakes exactly one task blocked in :meth:`get_connection`, since one
        connection (or pool slot) has just become available.
        """
        async with self._condition:
            await super().release(connection)
            self._condition.notify()