1import asyncio
2import copy
3import inspect
4import socket
5import sys
6import time
7import warnings
8import weakref
9from abc import ABC, abstractmethod
10from itertools import chain
11from types import MappingProxyType
12from typing import (
13 Any,
14 Callable,
15 Iterable,
16 List,
17 Mapping,
18 Optional,
19 Protocol,
20 Set,
21 Tuple,
22 Type,
23 TypedDict,
24 TypeVar,
25 Union,
26)
27from urllib.parse import ParseResult, parse_qs, unquote, urlparse
28
29from ..observability.attributes import (
30 DB_CLIENT_CONNECTION_POOL_NAME,
31 DB_CLIENT_CONNECTION_STATE,
32 AttributeBuilder,
33 ConnectionState,
34 get_pool_name,
35)
36from ..utils import SSL_AVAILABLE, deprecated_function
37
38if SSL_AVAILABLE:
39 import ssl
40 from ssl import SSLContext, TLSVersion, VerifyFlags
41else:
42 ssl = None
43 TLSVersion = None
44 SSLContext = None
45 VerifyFlags = None
46
47from ..auth.token import TokenInterface
48from ..driver_info import DriverInfo, resolve_driver_info
49from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
50from ..utils import deprecated_args, format_error_message
51
52# the functionality is available in 3.11.x but has a major issue before
53# 3.11.3. See https://github.com/redis/redis-py/issues/2633
54if sys.version_info >= (3, 11, 3):
55 from asyncio import timeout as async_timeout
56else:
57 from async_timeout import timeout as async_timeout
58
59from redis.asyncio.observability.recorder import (
60 record_connection_closed,
61 record_connection_count,
62 record_connection_create_time,
63 record_connection_wait_time,
64 record_error_count,
65)
66from redis.asyncio.retry import Retry
67from redis.backoff import NoBackoff
68from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
69from redis.exceptions import (
70 AuthenticationError,
71 AuthenticationWrongNumberOfArgsError,
72 ConnectionError,
73 DataError,
74 MaxConnectionsError,
75 RedisError,
76 ResponseError,
77 TimeoutError,
78)
79from redis.observability.metrics import CloseReason
80from redis.typing import EncodableT
81from redis.utils import (
82 DEFAULT_RESP_VERSION,
83 HIREDIS_AVAILABLE,
84 SENTINEL,
85 str_if_bytes,
86)
87
88from .._defaults import (
89 DEFAULT_SOCKET_CONNECT_TIMEOUT,
90 DEFAULT_SOCKET_READ_SIZE,
91 DEFAULT_SOCKET_TIMEOUT,
92 get_default_socket_keepalive_options,
93)
94from .._parsers import (
95 BaseParser,
96 Encoder,
97 _AsyncHiredisParser,
98 _AsyncRESP2Parser,
99 _AsyncRESP3Parser,
100)
101
# Byte tokens of the RESP wire format used by pack_command()/pack_commands():
# "*<n>\r\n" introduces an array, "$<len>\r\n" a bulk string.
SYM_STAR: bytes = b"*"
SYM_DOLLAR: bytes = b"$"
SYM_CRLF: bytes = b"\r\n"
SYM_LF: bytes = b"\n"
SYM_EMPTY: bytes = b""
107
108
# Parser class used when the caller does not pass ``parser_class``: the
# hiredis-backed parser when hiredis is importable, else the pure-Python
# RESP3 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP3Parser
)
114
115
class ConnectCallbackProtocol(Protocol):
    """Synchronous callable that is handed the connection after it connects."""

    def __call__(self, connection: "AbstractConnection"):
        ...
118
119
class AsyncConnectCallbackProtocol(Protocol):
    """Coroutine callable that is handed the connection after it connects."""

    async def __call__(self, connection: "AbstractConnection"):
        ...
122
123
# Either flavour of connect callback is accepted for ``redis_connect_func``.
ConnectCallbackT = Union[
    ConnectCallbackProtocol,
    AsyncConnectCallbackProtocol,
]
125
126
class AbstractConnection:
    """Manages communication to and from a Redis server.

    Subclasses provide the transport by implementing ``_connect``,
    ``_host_error`` and ``repr_pieces``. This base class implements the
    connection handshake (HELLO/AUTH, CLIENT SETNAME/SETINFO, SELECT),
    command packing and sending, response reading, PING health checks and
    disconnection on top of an ``asyncio`` stream reader/writer pair.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: str | int = 0,
        password: str | None = None,
        socket_timeout: float | None = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: float | None = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        retry_on_timeout: bool = False,
        retry_on_error: list | object = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        health_check_interval: float = 0,
        client_name: str | None = None,
        lib_name: str | object | None = SENTINEL,
        lib_version: str | object | None = SENTINEL,
        driver_info: DriverInfo | object | None = SENTINEL,
        username: str | None = None,
        retry: Retry | None = None,
        redis_connect_func: ConnectCallbackT | None = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: CredentialProvider | None = None,
        protocol: int | None = None,
        legacy_responses: bool = True,
        event_dispatcher: EventDispatcher | None = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version. Explicit None disables CLIENT SETINFO.
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        retry_on_error : list, optional
            Additional exception types to retry on. The value is copied, so the
            timeout errors added when ``retry_on_timeout`` is set never modify
            a list owned by the caller. ``None`` behaves like an empty list.
        socket_connect_timeout : float, optional
            Falls back to ``socket_timeout`` when ``None``.
        protocol : int, optional
            RESP version, 2 or 3. ``None`` selects ``DEFAULT_RESP_VERSION``.
        parser_class : type, optional
            The pure-Python parsers are protocol specific, so a RESP2 parser is
            swapped for RESP3 (and vice versa) to match ``protocol``; the
            hiredis parser is used as-is.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version.
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL or retry_on_error is None:
            retry_on_error = []
        else:
            # Work on a private copy: the appends below must not grow a list
            # owned by the caller (e.g. one shared through a pool's
            # connection kwargs by every connection it creates).
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self.legacy_responses = legacy_responses
        if parser_class != _AsyncHiredisParser:
            # The Python parsers are protocol-specific; hiredis supports both.
            if self.protocol == 3 and parser_class == _AsyncRESP2Parser:
                parser_class = _AsyncRESP3Parser
            elif self.protocol == 2 and parser_class == _AsyncRESP3Parser:
                parser_class = _AsyncRESP2Parser
        self.set_parser(parser_class)

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                # Closing the transport needs a running loop.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(name, value)`` pairs rendered by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected.

        The socket connect and the handshake are attempted as one unit, and
        that whole flow is retried according to ``self.retry``; every failed
        attempt disconnects before the next one.
        """
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def _record_connect_error(
        self, error: Exception, retry_attempts: int
    ) -> None:
        """Report a failed socket connect to the error-count metric."""
        await record_error_count(
            server_address=getattr(self, "host", None),
            server_port=getattr(self, "port", None),
            network_peer_address=getattr(self, "host", None),
            network_peer_port=getattr(self, "port", None),
            error_type=error,
            retry_attempts=retry_attempts,
            is_internal=False,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket and run the handshake; no-op if already connected.

        :param check_health: forwarded to ``on_connect_check_health`` when no
            ``redis_connect_func`` is configured.
        :param retry_socket_connect: retry only the socket connect
            (``_connect``) according to ``self.retry``.
        :raises TimeoutError: if connecting timed out.
        :raises ConnectionError: on socket errors or any other connect failure.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError) as exc:
            err = TimeoutError("Timeout connecting to server")
            await self._record_connect_error(err, actual_retry_attempts)
            raise err from exc
        except OSError as exc:
            err = ConnectionError(self._error_message(exc))
            await self._record_connect_error(err, actual_retry_attempts)
            raise err from exc
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func. It may be a plain
                # function, a coroutine function, or an object whose __call__
                # is async, so await whatever awaitable it hands back.
                result = self.redis_connect_func(self)
                if inspect.isawaitable(result):
                    await result
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection as needing a reconnect."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect`` was called since the last reset."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done by ``disconnect``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return the peer description used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the negotiated RESP version (2 or 3)."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake.

        Sends HELLO (RESP3) and/or AUTH, CLIENT SETNAME, CLIENT SETINFO and
        SELECT as configured. SETINFO and SELECT are pipelined and their
        replies read afterwards; SETINFO errors are ignored.

        :param check_health: passed to the commands sent after authentication.
        :raises AuthenticationError: if AUTH does not answer OK.
        :raises ConnectionError: on a RESP version mismatch, a failed
            SETNAME, or a failed SELECT.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # The HELLO reply is consumed but intentionally not validated here.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server.

        :param nowait: close the writer without awaiting ``wait_closed()``.
        :param error: the error that caused the disconnect, if any; selects
            the close reason reported to the metrics.
        :param failure_count: retry failure count; the error-count metric is
            recorded once it exceeds the configured number of retries.
        :param health_check_failed: report the close as a failed health check.
        :raises TimeoutError: if closing takes longer than
            ``socket_connect_timeout``.
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG.

        Only runs when ``health_check_interval`` is non-zero and the loop
        clock has passed ``next_health_check`` (pushed forward by every
        successful ``read_response``).
        """
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        :param command: a ``str``/``bytes`` payload or an iterable of chunks.
        :param check_health: run ``check_health()`` before writing.
        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on socket errors.

        The connection is disconnected on any failure, since a partially
        written command leaves it in an unknown state.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    @deprecated_function(
        version="8.0.0", reason="Use can_read() instead", name="can_read_destructive"
    )
    async def can_read_destructive(self) -> bool:
        """Check the socket to see if there's data loaded in the buffer."""
        try:
            return await self._parser.can_read()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def can_read(self) -> bool:
        """Check the socket to see if there's data loaded in the buffer."""
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        try:
            return await self._parser.can_read()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: forwarded to the parser.
        :param timeout: per-call read timeout. When it expires ``None`` is
            returned and the connection is left open. Defaults to
            ``socket_timeout``.
        :param disconnect_on_error: disconnect (without waiting) when the read
            fails.
        :param push_request: forwarded to the parser on RESP3 connections only.
        :raises TimeoutError: when ``socket_timeout`` expires.
        :raises ConnectionError: on socket errors.
        :raises ResponseError: when the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks. Arguments longer than the buffer
        cutoff, and memoryviews, are emitted as separate chunks instead of
        being copied into one large buffer.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced into buffers of roughly the buffer cutoff;
        large chunks and memoryviews are passed through on their own.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): reads the StreamReader's private ``_buffer``; this
        # relies on asyncio internals and requires a connected reader.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read pending push messages until the reader's buffer is drained."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used by the next ``re_auth()`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
895
896
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: str | int = 6379,
        socket_keepalive: bool = True,
        socket_keepalive_options: Mapping[int, int | bytes] | object | None = SENTINEL,
        socket_type: int = 0,
        **kwargs,
    ):
        """
        Initialize a TCP connection.

        Parameters
        ----------
        host : str
            Server host name or address.
        port : str | int
            Server port; converted with ``int()``.
        socket_keepalive : bool
            If `True`, TCP keepalive is enabled for TCP socket connections.
        socket_keepalive_options : Mapping[int, int | bytes] | object | None
            Mapping of TCP keepalive socket option constants to values, for
            example `{socket.TCP_KEEPIDLE: 30}`. If left unspecified, redis-py
            uses TCP keepalive defaults when `socket_keepalive` is enabled:
            idle 30 seconds, interval 5 seconds, and 3 probes. Platform-specific
            options that are not available are skipped. Pass `None` or `{}` to
            avoid setting additional TCP keepalive options.
        **kwargs
            Forwarded to ``AbstractConnection.__init__``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        if socket_keepalive_options is SENTINEL:
            socket_keepalive_options = get_default_socket_keepalive_options()
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return host, port, db and (when set) client_name for ``__repr__``."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection.

        The reader/writer pair is only published on ``self`` once socket
        options were applied successfully; on failure the writer is closed
        and the connection stays in the disconnected state.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    # The keepalive tuning options (TCP_KEEPIDLE, ...) live at
                    # the TCP level. IPPROTO_TCP is used instead of SOL_TCP
                    # because SOL_TCP is not defined on every platform.
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"
969
970
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted: without SSL support ``ssl`` is None, and an unquoted
        # ``ssl.VerifyMode`` would raise AttributeError while this class is
        # being defined (i.e. on module import), long before the
        # SSL_AVAILABLE check below could report the real problem.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Build the SSL settings and initialise the underlying TCP connection.

        Every ``ssl_*`` argument is forwarded (without the prefix) to
        :class:`RedisSSLContext`, which creates the actual ``ssl.SSLContext``
        lazily. ``ssl_cert_reqs`` may be ``"none"``, ``"optional"``,
        ``"required"`` or an ``ssl.VerifyMode`` value. The remaining
        ``kwargs`` go to :class:`Connection`.

        Raises:
            RedisError: if Python was built without SSL support, or
                ``ssl_cert_reqs`` is an unrecognised string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """Extend the TCP connection arguments with the (cached) SSL context."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    @property
    def keyfile(self):
        """Path of the client private key file, if any."""
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        """Path of the client certificate file, if any."""
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        """Resolved ``ssl.VerifyMode`` used for peer certificate checks."""
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        """Verify flags OR-ed into the context, if any."""
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        """Verify flags cleared from the context, if any."""
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        """CA bundle file passed to ``load_verify_locations``, if any."""
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        """In-memory CA data passed to ``load_verify_locations``, if any."""
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        """Whether hostname checking is enabled (always False for CERT_NONE)."""
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        """Minimum TLS version applied to the context, if any."""
        return self.ssl_context.min_version
1052
1053
class RedisSSLContext:
    """Holds TLS settings and lazily builds the ``ssl.SSLContext`` from them.

    The constructor validates and stores the settings. :meth:`get` creates
    the context on its first call and returns the cached object afterwards.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted: ``ssl`` is None without SSL support, and an unquoted
        # ``ssl.VerifyMode`` would raise AttributeError when this method is
        # defined (i.e. on module import).
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Store and normalise the TLS settings.

        Args:
            cert_reqs: ``None`` (treated as ``ssl.CERT_NONE``), one of the
                strings ``"none"``, ``"optional"`` or ``"required"``, or an
                ``ssl.VerifyMode`` value.
            check_hostname: forced to ``False`` when ``cert_reqs`` resolves
                to ``ssl.CERT_NONE``.
            The other arguments are applied to the context in :meth:`get`.

        Raises:
            RedisError: if Python was built without SSL support, or
                ``cert_reqs`` is an unrecognised string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is meaningless without certificate verification.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the SSL context, building and caching it on the first call.

        Starts from ``ssl.create_default_context()`` and then applies, in
        order:
        - hostname checking and the verify mode;
        - added/removed verify flags;
        - the client certificate chain;
        - the CA locations;
        - the minimum TLS version and the cipher list.
        """
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname is assigned before verify_mode: ssl refuses to set
            # CERT_NONE while hostname checking is on, and __init__ guarantees
            # check_hostname is False whenever cert_reqs is CERT_NONE.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
1145
1146
class UnixDomainSocketConnection(AbstractConnection):
    """Connection to a Redis server over a Unix domain socket."""

    def __init__(self, *, path: str = "", **kwargs):
        """Remember the socket ``path``; every other kwarg goes to the base class."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return ``(name, value)`` pairs for the repr: path, db, optional client_name."""
        head = [("path", self.path), ("db", self.db)]
        return (head + [("client_name", self.client_name)]) if self.client_name else head

    async def _connect(self):
        """Open the Unix-domain stream pair, store it, then run ``on_connect()``.

        NOTE(review): the TCP ``Connection._connect`` does not call
        ``on_connect()`` itself; confirm whether the base ``connect()`` also
        calls it, which would run the handshake twice for UDS connections.
        """
        async with async_timeout(self.socket_connect_timeout):
            streams = await asyncio.open_unix_connection(path=self.path)
        self._reader, self._writer = streams
        await self.on_connect()

    def _host_error(self) -> str:
        """Identify the endpoint in error messages by its socket path."""
        return self.path
1169
1170
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a configuration/URL value as a boolean.

    ``None`` and the empty string mean "not specified" and yield ``None``.
    Strings whose upper-cased form is in ``FALSE_STRINGS`` yield ``False``.
    Everything else is converted with ``bool()``.
    """
    if value is None or value == "":
        return None
    if not isinstance(value, str):
        return bool(value)
    return value.upper() not in FALSE_STRINGS and bool(value)
1180
1181
def parse_ssl_verify_flags(value):
    """Parse a textual list of ``ssl.VerifyFlags`` member names.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``. Square brackets are
    optional, whitespace around names is ignored, and empty entries (as in
    ``"[]"`` or a trailing comma) are skipped.

    Returns:
        A list of ``VerifyFlags`` members, in the order given.

    Raises:
        ValueError: if a name is not a ``VerifyFlags`` member, including
            whenever Python was built without SSL support.
    """
    # Only real enum members are accepted. A plain hasattr() check would also
    # accept arbitrary class attributes such as "mro" or "__doc__" and return
    # non-flag objects. VerifyFlags is None without SSL support, which leaves
    # an empty member table so every name is rejected.
    members = getattr(VerifyFlags, "__members__", {})
    verify_flags_str = value.replace("[", "").replace("]", "")

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
1194
1195
# Converters that parse_url applies to querystring values with these names.
# Names not listed here are passed through as (unquoted) strings. The table is
# wrapped in MappingProxyType so it is read-only.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_read_size=int,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        ssl_include_verify_flags=parse_ssl_verify_flags,
        ssl_exclude_verify_flags=parse_ssl_verify_flags,
        ssl_min_version=int,
        timeout=float,
        protocol=int,
        legacy_responses=to_bool,
    )
)
1215
1216
class ConnectKwargs(TypedDict, total=False):
    """Keyword arguments produced by ``parse_url``.

    ``total=False`` makes every key optional; ``parse_url`` only sets the keys
    the URL actually supplies. Note that ``parse_url`` also stores querystring
    options under their own names (e.g. ``socket_timeout`` or any unrecognised
    parameter); those keys are not declared here.
    """

    username: str
    password: str
    # Connection class chosen from the URL scheme (UDS or SSL).
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    # Socket path, only for unix:// URLs.
    path: str
1225
1226
def parse_url(url: str) -> ConnectKwargs:
    """Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into pool kwargs.

    Querystring values (only the first one for a repeated name) are passed
    through ``unquote``. Names listed in ``URL_QUERY_ARGUMENT_PARSERS`` are
    converted by the matching parser; other names are kept as strings. Username
    and password come from the URL's userinfo. For ``unix://`` the path becomes
    ``path``. For ``redis://``/``rediss://`` the host and port are taken from
    the URL, and a numeric path becomes ``db`` unless a ``db`` query option was
    given.

    Raises:
        ValueError: if a querystring value cannot be converted, or the scheme
            is not supported.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        text = unquote(values[0])
        converter = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not converter:
            kwargs[name] = text
            continue
        try:
            kwargs[name] = converter(text)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    for credential in ("username", "password"):
        raw_credential = getattr(parsed, credential)
        if raw_credential:
            kwargs[credential] = unquote(raw_credential)

    scheme = parsed.scheme
    # We only support redis://, rediss:// and unix:// schemes.
    if scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    if scheme not in ("redis", "rediss"):
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A path such as "/3" selects the database, unless the querystring
    # already supplied "db"; non-numeric paths are ignored.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs
1278
1279
# Bound TypeVar used by ConnectionPool.from_url (``cls: Type[_CP]`` -> ``_CP``)
# so the classmethod is typed as returning the subclass it is called on.
_CP = TypeVar("_CP", bound="ConnectionPool")
1281
1282
class ConnectionPoolInterface(ABC):
    """Abstract base listing the operations a connection pool must provide."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version, or ``None`` if unset."""

    @abstractmethod
    def reset(self) -> None:
        """Drop the pool's bookkeeping of idle and in-use connections."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> "AbstractConnection":
        """Hand out a connection; passing any arguments is deprecated."""

    @abstractmethod
    def get_encoder(self) -> "Encoder":
        """Return an encoder built from the pool's encoding settings."""

    @abstractmethod
    async def release(self, connection: "AbstractConnection") -> None:
        """Give a previously acquired connection back to the pool."""

    @abstractmethod
    async def disconnect(self, inuse_connections: bool = True) -> None:
        """Disconnect idle connections, and in-use ones when requested."""

    @abstractmethod
    async def aclose(self) -> None:
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: "Retry") -> None:
        """Apply ``retry`` to the connections managed by the pool."""

    @abstractmethod
    async def re_auth_callback(self, token: TokenInterface) -> None:
        """Re-authenticate the pool's connections with ``token``."""

    @abstractmethod
    def get_connection_count(self) -> List[Tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
1333
1334
class ConnectionPool(ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached. When ``max_connections`` is omitted (or falsy)
    the limit defaults to 100.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """Configure the pool. No connection is created until one is requested.

        Args:
            connection_class: class instantiated with ``connection_kwargs``
                for every new connection.
            max_connections: upper bound on connections in use at once; a
                falsy value selects the default of 100.
            **connection_kwargs: forwarded to ``connection_class``. The pool
                also reads ``encoder_class`` and ``event_dispatcher`` from it.

        Raises:
            ValueError: if ``max_connections`` is not a non-negative int.
        """
        max_connections = max_connections or 100
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        # NOTE(review): a fresh pool tracks in-use connections in a strong
        # ``set`` while ``reset()`` installs a ``weakref.WeakSet``. Kept as
        # is; confirm which semantics are intended.
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self):
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def _record_removed_connection_counts(self) -> None:
        """Decrement the IDLE/USED connection-count metrics for every tracked connection.

        Shared by ``reset()`` and ``__del__()``. Both are synchronous, so the
        synchronous recorder is used. Does nothing if the connection containers
        do not exist yet, e.g. when ``__init__`` raised before creating them.
        """
        if not hasattr(self, "_available_connections") or not hasattr(
            self, "_in_use_connections"
        ):
            return
        idle_count = len(self._available_connections)
        in_use_count = len(self._in_use_connections)
        if idle_count == 0 and in_use_count == 0:
            return
        pool_name = get_pool_name(self)
        from redis.observability.recorder import (
            record_connection_count as sync_record_connection_count,
        )

        if idle_count > 0:
            sync_record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-idle_count,
            )
        if in_use_count > 0:
            sync_record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=-in_use_count,
            )

    def reset(self):
        """Forget all tracked connections, recording their removal in the metrics first."""
        self._record_removed_connection_counts()
        self._available_connections = []
        self._in_use_connections = weakref.WeakSet()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        try:
            self._record_removed_connection_counts()
        except Exception:
            # Best effort: a finalizer must not raise (e.g. during interpreter
            # shutdown, when the recorder import can fail).
            pass

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        return (
            bool(self._available_connections)
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool.

        Under the pool lock, reuses an idle connection or creates a new one and
        records the IDLE/USED transition. Outside the lock, it makes sure the
        connection is connected. If that fails, the connection is released back
        to the pool and the error is re-raised.

        Raises:
            MaxConnectionsError: if the pool is exhausted.
        """
        # Track connection count before to detect if a new connection is created
        async with self._lock:
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

            # Record state transition for observability
            # This ensures counters stay balanced if ensure_connection() fails and release() is called
            pool_name = get_pool_name(self)
            if is_created:
                # New connection created and acquired: just USED +1
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )
            else:
                # Existing connection acquired from pool: IDLE -> USED
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=-1,
                )
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected.

        Prefers the most recently released idle connection; otherwise creates
        a new one.

        Raises:
            MaxConnectionsError: if no idle connection exists and
                ``max_connections`` connections are already in use.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        # Note: We don't record IDLE here because async uses a sync make_connection
        # but async record_connection_count. The recording is handled in get_connection.
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid.

        Raises:
            ConnectionError: if, even after a reconnect, the connection still
                has unread data.
        """
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        A connection flagged by ``should_reconnect()`` is disconnected first.
        The release event is dispatched and the USED -> IDLE transition is
        recorded.

        Raises:
            KeyError: if ``connection`` is not currently in use by this pool.
        """
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)

        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

        # Record state transition: USED -> IDLE
        pool_name = get_pool_name(self)
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=-1,
        )
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects are attempted concurrently. Afterwards the first
        exception encountered, if any, is re-raised.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )

        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign ``retry`` to every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate with ``token``.

        Idle connections send ``AUTH`` immediately (with retries). In-use
        connections only get the token stored via ``set_re_auth_token``.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas capture the loop variable ``conn``, which is
                # safe here because each call is awaited before the loop moves on.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[Tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).

        Returns:
            Two ``(count, attributes)`` pairs: idle connections first, then
            in-use connections. Each pair is tagged with the pool name and
            connection state.
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]
1729
1730
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        """Create the pool.

        ``queue_class`` is accepted for backward compatibility only and is not
        used. Waiters are coordinated with an ``asyncio.Condition``.
        """
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        self._condition = asyncio.Condition()
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available.

        Waits up to ``self.timeout`` seconds (no limit when ``None``) until
        ``can_get_connection()`` is true. It then reuses an idle connection or
        creates one, records the IDLE/USED transition, and makes sure the
        connection is connected. On any failure after acquisition the
        connection is released and the error re-raised.

        Raises:
            ConnectionError: if no connection became available in time.
        """
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            # Record the acquisition exactly like ConnectionPool.get_connection:
            # IDLE -> USED for a reused connection, or USED +1 for a new one.
            # release() always records USED -1 / IDLE +1, so without this step
            # the pool's connection-count metrics drift on every checkout.
            # It sits inside the try so the connection is still released if
            # recording fails.
            pool_name = get_pool_name(self)
            if not is_created:
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=-1,
                )
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )

            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            await record_connection_wait_time(
                pool_name=pool_name,
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool and wakes one waiting task."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify()