1import asyncio
2import copy
3import enum
4import inspect
5import socket
6import sys
7import time
8import warnings
9import weakref
10from abc import abstractmethod
11from itertools import chain
12from types import MappingProxyType
13from typing import (
14 Any,
15 Callable,
16 Iterable,
17 List,
18 Mapping,
19 Optional,
20 Protocol,
21 Set,
22 Tuple,
23 Type,
24 TypedDict,
25 TypeVar,
26 Union,
27)
28from urllib.parse import ParseResult, parse_qs, unquote, urlparse
29
30from ..observability.attributes import (
31 DB_CLIENT_CONNECTION_POOL_NAME,
32 DB_CLIENT_CONNECTION_STATE,
33 AttributeBuilder,
34 ConnectionState,
35 get_pool_name,
36)
37from ..utils import SSL_AVAILABLE
38
39if SSL_AVAILABLE:
40 import ssl
41 from ssl import SSLContext, TLSVersion, VerifyFlags
42else:
43 ssl = None
44 TLSVersion = None
45 SSLContext = None
46 VerifyFlags = None
47
48from ..auth.token import TokenInterface
49from ..driver_info import DriverInfo, resolve_driver_info
50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
51from ..utils import deprecated_args, format_error_message
52
53# the functionality is available in 3.11.x but has a major issue before
54# 3.11.3. See https://github.com/redis/redis-py/issues/2633
55if sys.version_info >= (3, 11, 3):
56 from asyncio import timeout as async_timeout
57else:
58 from async_timeout import timeout as async_timeout
59
60from redis.asyncio.observability.recorder import (
61 record_connection_closed,
62 record_connection_create_time,
63 record_connection_wait_time,
64 record_error_count,
65)
66from redis.asyncio.retry import Retry
67from redis.backoff import NoBackoff
68from redis.connection import DEFAULT_RESP_VERSION
69from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
70from redis.exceptions import (
71 AuthenticationError,
72 AuthenticationWrongNumberOfArgsError,
73 ConnectionError,
74 DataError,
75 MaxConnectionsError,
76 RedisError,
77 ResponseError,
78 TimeoutError,
79)
80from redis.observability.metrics import CloseReason
81from redis.typing import EncodableT
82from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
83
84from .._parsers import (
85 BaseParser,
86 Encoder,
87 _AsyncHiredisParser,
88 _AsyncRESP2Parser,
89 _AsyncRESP3Parser,
90)
91
# RESP wire-protocol tokens used by the command packers.
SYM_STAR = b"*"  # prefix of the argument-count header
SYM_DOLLAR = b"$"  # prefix of each argument's length header
SYM_CRLF = b"\r\n"  # line terminator
SYM_LF = b"\n"
SYM_EMPTY = b""  # joiner used to concatenate byte pieces


class _Sentinel(enum.Enum):
    """Single-member enum used as a typed "argument not supplied" marker.

    Distinguishes an omitted argument from an explicit ``None``.
    """

    sentinel = object()


# The one and only marker value; compare with ``is``.
SENTINEL = _Sentinel.sentinel
104
105
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
# Use the hiredis-backed parser when HIREDIS_AVAILABLE reports it can be used;
# otherwise fall back to the RESP2 parser.
DefaultParser = _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
111
112
class ConnectCallbackProtocol(Protocol):
    """Plain callable that is handed the connection once it has connected."""

    def __call__(self, connection: "AbstractConnection"):
        """Handle a freshly (re)established ``connection``."""


class AsyncConnectCallbackProtocol(Protocol):
    """Coroutine function that is handed the connection once it has connected."""

    async def __call__(self, connection: "AbstractConnection"):
        """Handle a freshly (re)established ``connection``."""


# Either flavour is accepted wherever a connect callback is expected.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
122
123
class AbstractConnection:
    """Manages communication to and from a Redis server.

    This base class holds the connection settings and implements the Redis
    handshake (AUTH / HELLO, CLIENT SETNAME / SETINFO, SELECT), command
    packing, health checks and response reading on top of an ``asyncio``
    reader/writer pair. Subclasses supply the transport by implementing
    ``_connect`` (expected to populate ``_reader`` and ``_writer``),
    ``_host_error`` and ``repr_pieces``.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        No I/O is performed here; the transport is opened by ``connect``.

        Parameters
        ----------
        socket_connect_timeout : float, optional
            Timeout for opening (and closing) the transport. Falls back to
            ``socket_timeout`` when None.
        retry_on_timeout : bool
            When ``retry_on_error`` is not given and this is True,
            ``TimeoutError``, ``socket.timeout`` and ``asyncio.TimeoutError``
            are added to the retried errors.
        retry_on_error : list, optional
            Extra exception types to retry on. When this or ``retry`` is
            provided, ``retry`` (or a single-attempt ``Retry``) is deep-copied
            and extended with these errors; otherwise no retries happen.
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        protocol : int, optional
            RESP version, 2 or 3. ``None`` selects ``DEFAULT_RESP_VERSION``.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
            if retry_on_timeout:
                retry_on_error.append(TimeoutError)
                retry_on_error.append(socket.timeout)
                retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        try:
            p = int(protocol)
        except TypeError:
            # protocol=None (or another non-numeric type): use the default.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p

    def __del__(self, _warnings: Any = warnings):
        """Warn about, and best-effort close, a connection that was never closed."""
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        #
        # ``getattr`` is required because ``__init__`` may have raised before
        # ``_writer`` was assigned (e.g. the credential_provider DataError).
        # The close attempt lives inside this guard so such half-built
        # objects do not raise AttributeError from ``_close`` during GC.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No running event loop (e.g. it is already closed): nothing
                # can be done from here.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs rendered by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the reader and the writer streams are present."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the transport, run the handshake, then fire connect callbacks.

        Returns immediately when already connected. When
        ``retry_socket_connect`` is True only the socket connect is wrapped in
        ``self.retry``; the handshake is attempted once.

        :raises TimeoutError: the socket connect timed out.
        :raises ConnectionError: the socket connect failed for any other reason.
        :raises RedisError: the handshake failed (the connection is
            disconnected before the error propagates).
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            # Remember how many attempts were made, then tear down.
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            timeout_err = TimeoutError("Timeout connecting to server")
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=timeout_err,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise timeout_err
        except OSError as e:
            conn_err = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=conn_err,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise conn_err
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag the connection so callers know it should be re-established."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect`` was called since the last reset."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done by ``disconnect``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport; implementations populate ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a short description of the peer used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake.

        Authenticates (via HELLO for RESP3, AUTH otherwise), switches protocol,
        sets the client name and library info, and selects ``self.db``.

        :raises AuthenticationError: AUTH did not return OK.
        :raises ConnectionError: HELLO, CLIENT SETNAME or SELECT failed.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The proto key may be bytes or str depending on decoding.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # SETINFO errors are deliberately ignored.
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server.

        :param nowait: skip waiting for the writer to finish closing.
        :param error: the error that caused the disconnect, if any; selects
            the recorded close reason.
        :param failure_count: attempts made so far; an error metric is
            recorded once it exceeds the configured retries.
        :param health_check_failed: record the close as a failed health check.
        :raises TimeoutError: closing took longer than ``socket_connect_timeout``.

        Returns early (recording nothing) when not connected.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG.

        Only pings when ``health_check_interval`` is set and the loop clock has
        passed ``next_health_check`` (advanced by ``read_response``).
        """
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        :raises TimeoutError: the write exceeded ``socket_timeout``.
        :raises ConnectionError: the socket write failed.

        Any failure disconnects the connection before propagating.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout.  Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param timeout: per-call timeout; on expiry ``None`` is returned
            instead of raising. Defaults to ``socket_timeout``.
        :param disconnect_on_error: disconnect on read failures.
        :param push_request: forwarded to the parser for RESP3 only.
        :raises TimeoutError: ``socket_timeout`` expired.
        :raises ConnectionError: the socket read failed.
        :raises ResponseError: the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException.  A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks. Arguments longer than
        ``_buffer_cutoff`` and memoryviews are emitted as their own chunks so
        they are not copied into a larger buffer.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced into buffers of roughly ``_buffer_cutoff``
        bytes; large chunks and memoryviews are passed through on their own.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): relies on the private StreamReader._buffer attribute.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Drain buffered push messages until the reader buffer is empty."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the pending token, if any, then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
860
861
class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server"""

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store TCP settings; remaining kwargs go to ``AbstractConnection``.

        :param port: converted to ``int``.
        :param socket_keepalive: enable SO_KEEPALIVE on the socket.
        :param socket_keepalive_options: extra ``SOL_TCP`` options applied
            only when ``socket_keepalive`` is True.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection.

        The stream pair is stored on the instance only after the socket has
        been configured. If setting TCP_NODELAY or the keepalive options
        fails, the writer is closed and the error re-raised while
        ``is_connected`` stays False. A later call then reconnects instead of
        reusing the closed transport.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"
917
918
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted so the annotation is not evaluated at definition time: when
        # SSL is unavailable ``ssl`` is None, and an unquoted
        # ``ssl.VerifyMode`` would break importing this module.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Build a ``RedisSSLContext`` from the ``ssl_*`` arguments.

        Remaining kwargs are forwarded to ``Connection``.

        :raises RedisError: Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP connection arguments plus the SSL context to wrap the socket with."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views onto the settings held by ``self.ssl_context``.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
1000
1001
class RedisSSLContext:
    """Hold TLS settings and lazily build the matching ``ssl.SSLContext``.

    The constructor only validates and normalizes the settings. The real
    ``SSLContext`` is created by the first call to :meth:`get` and cached,
    so later calls return the same object.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Record the TLS settings; no ``SSLContext`` is built yet.

        Args:
            keyfile: client private-key file for ``load_cert_chain``.
            certfile: client certificate file for ``load_cert_chain``.
            cert_reqs: an ``ssl`` verify mode, or one of the strings
                ``"none"``, ``"optional"`` or ``"required"`` (any letter
                case). ``None`` means ``ssl.CERT_NONE``.
            include_verify_flags: ``ssl.VerifyFlags`` OR-ed into the context.
            exclude_verify_flags: ``ssl.VerifyFlags`` cleared from the context.
            ca_certs: CA file (``cafile``) for ``load_verify_locations``.
            ca_data: CA data (``cadata``) for ``load_verify_locations``.
            ca_path: CA directory (``capath``) for ``load_verify_locations``.
            check_hostname: requested hostname checking. Forced to ``False``
                when the resolved ``cert_reqs`` is ``ssl.CERT_NONE``.
            min_version: minimum TLS version, if any.
            ciphers: cipher string for ``SSLContext.set_ciphers``, if any.
            password: password for the private key, used by
                ``load_cert_chain``.

        Raises:
            RedisError: if Python lacks SSL support, or ``cert_reqs`` is an
                unrecognised string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = self._parse_cert_reqs(cert_reqs)
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # ssl does not allow CERT_NONE together with hostname checking, so
        # hostname checking is dropped whenever no certificate is required.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        # Built lazily by get().
        self.context: Optional[SSLContext] = None

    @staticmethod
    def _parse_cert_reqs(cert_reqs):
        """Normalize ``cert_reqs`` to an ``ssl`` verify mode.

        ``None`` maps to ``ssl.CERT_NONE``. The strings ``"none"``,
        ``"optional"`` and ``"required"`` map to the matching ``ssl.CERT_*``
        constant, and letter case is ignored. Any other non-string value is
        returned unchanged.

        Raises:
            RedisError: if ``cert_reqs`` is a string that is not one of the
                accepted names.
        """
        if cert_reqs is None:
            return ssl.CERT_NONE
        if not isinstance(cert_reqs, str):
            return cert_reqs
        cert_reqs_by_name = {
            "none": ssl.CERT_NONE,
            "optional": ssl.CERT_OPTIONAL,
            "required": ssl.CERT_REQUIRED,
        }
        verify_mode = cert_reqs_by_name.get(cert_reqs.lower())
        if verify_mode is None:
            raise RedisError(
                f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
            )
        return verify_mode

    def get(self) -> SSLContext:
        """Return the ``SSLContext``, building and caching it on first use.

        If building fails, nothing is cached and the next call tries again.
        """
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname must be assigned before verify_mode: the default
            # context has hostname checking on, and ssl rejects CERT_NONE
            # while it is enabled.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context
1093
1094
class UnixDomainSocketConnection(AbstractConnection):
    """Connection that talks to a Redis server over a Unix domain socket.

    The socket location is given by the keyword-only ``path`` argument; every
    other keyword argument is forwarded to ``AbstractConnection``.
    """

    def __init__(self, *, path: str = "", **kwargs):
        # Set before the base initializer runs so the path is already present.
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return the (name, value) pairs shown in this connection's repr."""
        client_name_piece = (
            [("client_name", self.client_name)] if self.client_name else []
        )
        return [("path", self.path), ("db", self.db), *client_name_piece]

    async def _connect(self):
        """Open the socket within ``socket_connect_timeout``, then run on_connect.

        Only opening the socket is bounded by the timeout; ``on_connect()``
        runs after the timeout context has exited.
        """
        connect_timeout = self.socket_connect_timeout
        async with async_timeout(connect_timeout):
            stream_reader, stream_writer = await asyncio.open_unix_connection(
                path=self.path
            )
        self._reader, self._writer = stream_reader, stream_writer
        await self.on_connect()

    def _host_error(self) -> str:
        """Describe the endpoint for error messages: the socket path."""
        socket_path = self.path
        return socket_path
1117
1118
# Upper-cased spellings that to_bool() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a (typically query-string) value as a boolean.

    Returns ``None`` for ``None`` or the empty string. Returns ``False`` for
    strings whose upper-cased form is in ``FALSE_STRINGS`` (e.g. ``"no"``,
    ``"0"``). Returns ``bool(value)`` for everything else.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)
1128
1129
def parse_ssl_verify_flags(value):
    """Parse a textual list of ``ssl.VerifyFlags`` member names.

    Flags are passed as a string representation of a list, with or without
    brackets, e.g. ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``.
    Whitespace around names is ignored. Empty entries (from a trailing comma
    or ``"[]"``) are skipped.

    Args:
        value: the raw option string, e.g. from a connection-URL query.

    Returns:
        A list of ``ssl.VerifyFlags`` members, in the order given.

    Raises:
        ValueError: if a name is not a member of ``ssl.VerifyFlags``. When
            Python lacks SSL support, every name raises this error.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # Look names up in the enum's member table: hasattr()/getattr() on the
    # class would also accept non-flag attributes such as "__doc__" or "mro".
    # VerifyFlags is None without SSL support, hence the empty fallback.
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for raw_flag in verify_flags_str.split(","):
        flag = raw_flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags
1142
1143
# Converters for recognised connection-URL query options, looked up by
# parse_url(); options without an entry are kept as strings. Exposed as a
# read-only MappingProxyType view.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        ssl_include_verify_flags=parse_ssl_verify_flags,
        ssl_exclude_verify_flags=parse_ssl_verify_flags,
        timeout=float,
    )
)
1159
1160
# Keyword arguments produced by parse_url(); every key is optional
# (total=False). NOTE(review): parse_url() also stores arbitrary query-string
# options under their own names, which this type does not declare.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)
1169
1170
def parse_url(url: str) -> ConnectKwargs:
    """Translate a Redis connection URL into connection keyword arguments.

    Supported schemes:

    - ``redis://``: TCP.
    - ``rediss://``: TCP using :class:`SSLConnection`.
    - ``unix://``: using :class:`UnixDomainSocketConnection`.

    Query-string options:

    - Only the first value of each option is used, after ``unquote``.
    - Options listed in ``URL_QUERY_ARGUMENT_PARSERS`` are converted by
      their parser; others are kept as strings.

    For TCP schemes, the URL path supplies ``db`` when no ``db`` query option
    was given and the path parses as an integer.

    Raises:
        ValueError: if a registered parser rejects its value, if the scheme
            is unsupported, or (from ``ParseResult.port``) if the port is
            invalid.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        # NOTE(review): parse_qs() already percent-decodes values, so this is
        # a second decoding pass ("%2525" ends up as "%"); kept for
        # backward compatibility.
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = parser(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    scheme = parsed.scheme
    if scheme not in ("unix", "redis", "rediss"):
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    if scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # Remaining schemes: redis:// and rediss:// (TCP).
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    port = parsed.port
    if port:
        kwargs["port"] = int(port)

    # A path such as "/2" selects the database unless a ``db`` query option
    # was given; a non-numeric path is ignored.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if scheme == "rediss":
        kwargs["connection_class"] = SSLConnection
    return kwargs
1222
1223
# Type variable for the class ConnectionPool.from_url() is called on, so a
# subclass calling it is typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")
1225
1226
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL options override explicit keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """Configure the pool; no connection is opened here.

        Args:
            connection_class: class instantiated by :meth:`make_connection`.
            max_connections: limit on connections in use. A new connection is
                created only when no idle one exists and fewer than this many
                are in use. ``None`` or ``0`` mean ``2**31``, which is
                effectively unlimited.
            **connection_kwargs: forwarded to ``connection_class``. The pool
                itself also reads the ``encoder_class`` and
                ``event_dispatcher`` entries.

        Raises:
            ValueError: if ``max_connections`` is negative or not an int.
        """
        # A falsy value (None or 0) means "no practical limit".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all pooled connections without disconnecting them.

        NOTE(review): ``__init__`` tracks in-use connections in a plain
        ``set``, but this method switches to a ``weakref.WeakSet``. After a
        reset, an in-use connection that is dropped without being released
        stops counting against ``max_connections``. Confirm which behavior is
        intended.
        """
        self._available_connections = []
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool.

        This is the case when an idle connection exists, or when fewer than
        ``max_connections`` connections are in use.
        """
        # bool(): a non-empty list is truthy but is not itself a bool.
        return (
            bool(self._available_connections)
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool.

        The positional and keyword arguments are deprecated and unused.

        Raises:
            MaxConnectionsError: (from :meth:`get_available_connection`) if no
                idle connection exists and the in-use limit is reached.

        Any exception raised after checkout (while connecting or recording
        metrics) returns the connection to the pool before being re-raised.
        """
        async with self._lock:
            # The pool's total size grows only when get_available_connection()
            # had to create a new connection, so compare before and after.
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                # The recorded creation time includes ensure_connection().
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            # BaseException so that cancellation also returns the connection.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Check out a connection without making sure it is connected.

        Reuses the most recently released idle connection if there is one.
        Otherwise a new one is created via :meth:`make_connection`. The
        connection is recorded as in use either way.

        Raises:
            MaxConnectionsError: if no idle connection exists and
                ``max_connections`` connections are already in use.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid.

        If the connection has unread data, or checking it fails with a
        connection, timeout or OS error, it is disconnected and connected
        once more.

        Raises:
            ConnectionError: if the fresh connection still has readable data.
        """
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                # Caught just below to trigger the reconnect path.
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        A connection whose ``should_reconnect()`` is true is disconnected
        before being made idle. An ``AsyncAfterConnectionReleasedEvent`` is
        dispatched afterwards.

        Raises:
            KeyError: if ``connection`` is not in use in this pool.
        """
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        Every disconnect is attempted. If any of them failed, the first
        failure (in connection order) is re-raised afterwards.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all in-use connections for reconnect.

        Idle connections are left untouched. :meth:`release` disconnects a
        connection whose ``should_reconnect()`` is true; presumably that is
        the effect of ``mark_for_reconnect()`` (TODO confirm).
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign ``retry`` to every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a new token.

        For each idle connection, ``AUTH <oid> <token value>`` is sent and its
        reply read immediately, each step through the connection's retry
        policy. In-use connections only get the token stored via
        ``set_re_auth_token``; presumably they apply it later (TODO confirm).
        """
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).

        The result is two ``(count, attributes)`` pairs: first for idle
        connections, then for in-use ones. Each attributes dict carries the
        pool name and the matching connection state.
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]
1500
1501
1502class BlockingConnectionPool(ConnectionPool):
1503 """
1504 A blocking connection pool::
1505
1506 >>> from redis.asyncio import Redis, BlockingConnectionPool
1507 >>> client = Redis.from_pool(BlockingConnectionPool())
1508
1509 It performs the same function as the default
1510 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
1511 it maintains a pool of reusable connections that can be shared by
1512 multiple async redis clients.
1513
1514 The difference is that, in the event that a client tries to get a
1515 connection from the pool when all of connections are in use, rather than
1516 raising a :py:class:`~redis.ConnectionError` (as the default
1517 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
1518 blocks the current `Task` for a specified number of seconds until
1519 a connection becomes available.
1520
1521 Use ``max_connections`` to increase / decrease the pool size::
1522
1523 >>> pool = BlockingConnectionPool(max_connections=10)
1524
1525 Use ``timeout`` to tell it either how many seconds to wait for a connection
1526 to become available, or to block forever:
1527
1528 >>> # Block forever.
1529 >>> pool = BlockingConnectionPool(timeout=None)
1530
1531 >>> # Raise a ``ConnectionError`` after five seconds if a connection is
1532 >>> # not available.
1533 >>> pool = BlockingConnectionPool(timeout=5)
1534 """
1535
1536 def __init__(
1537 self,
1538 max_connections: int = 50,
1539 timeout: Optional[float] = 20,
1540 connection_class: Type[AbstractConnection] = Connection,
1541 queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
1542 **connection_kwargs,
1543 ):
1544 super().__init__(
1545 connection_class=connection_class,
1546 max_connections=max_connections,
1547 **connection_kwargs,
1548 )
1549 self._condition = asyncio.Condition()
1550 self.timeout = timeout
1551
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available.

        The task waits on ``self._condition`` until ``can_get_connection()``
        is true. The wait lasts at most ``self.timeout`` seconds, and ``None``
        means no limit. The positional and keyword arguments are deprecated
        and unused.

        Raises:
            ConnectionError: if no connection became available within
                ``self.timeout``.

        Any exception raised after checkout (while connecting or recording
        metrics) returns the connection to the pool before being re-raised.
        """
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            # The recorded wait time spans from entry to this method until
            # after ensure_connection(), so it includes connect/check time.
            await record_connection_wait_time(
                pool_name=get_pool_name(self),
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            # BaseException so that cancellation also returns the connection.
            await self.release(connection)
            raise
1598
1599 async def release(self, connection: AbstractConnection):
1600 """Releases the connection back to the pool."""
1601 async with self._condition:
1602 await super().release(connection)
1603 self._condition.notify()