Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

788 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import time 

8import warnings 

9import weakref 

10from abc import abstractmethod 

11from itertools import chain 

12from types import MappingProxyType 

13from typing import ( 

14 Any, 

15 Callable, 

16 Iterable, 

17 List, 

18 Mapping, 

19 Optional, 

20 Protocol, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypedDict, 

25 TypeVar, 

26 Union, 

27) 

28from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

29 

30from ..observability.attributes import ( 

31 DB_CLIENT_CONNECTION_POOL_NAME, 

32 DB_CLIENT_CONNECTION_STATE, 

33 AttributeBuilder, 

34 ConnectionState, 

35 get_pool_name, 

36) 

37from ..utils import SSL_AVAILABLE 

38 

39if SSL_AVAILABLE: 

40 import ssl 

41 from ssl import SSLContext, TLSVersion, VerifyFlags 

42else: 

43 ssl = None 

44 TLSVersion = None 

45 SSLContext = None 

46 VerifyFlags = None 

47 

48from ..auth.token import TokenInterface 

49from ..driver_info import DriverInfo, resolve_driver_info 

50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

51from ..utils import deprecated_args, format_error_message 

52 

53# the functionality is available in 3.11.x but has a major issue before 

54# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

55if sys.version_info >= (3, 11, 3): 

56 from asyncio import timeout as async_timeout 

57else: 

58 from async_timeout import timeout as async_timeout 

59 

60from redis.asyncio.observability.recorder import ( 

61 record_connection_closed, 

62 record_connection_count, 

63 record_connection_create_time, 

64 record_connection_wait_time, 

65 record_error_count, 

66) 

67from redis.asyncio.retry import Retry 

68from redis.backoff import NoBackoff 

69from redis.connection import DEFAULT_RESP_VERSION 

70from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

71from redis.exceptions import ( 

72 AuthenticationError, 

73 AuthenticationWrongNumberOfArgsError, 

74 ConnectionError, 

75 DataError, 

76 MaxConnectionsError, 

77 RedisError, 

78 ResponseError, 

79 TimeoutError, 

80) 

81from redis.observability.metrics import CloseReason 

82from redis.typing import EncodableT 

83from redis.utils import HIREDIS_AVAILABLE, str_if_bytes 

84 

85from .._parsers import ( 

86 BaseParser, 

87 Encoder, 

88 _AsyncHiredisParser, 

89 _AsyncRESP2Parser, 

90 _AsyncRESP3Parser, 

91) 

92 

# Byte fragments used to frame commands in the Redis wire protocol.
SYM_STAR = b"*"  # prefixes the argument count of a packed command
SYM_DOLLAR = b"$"  # prefixes the byte length of each packed argument
SYM_CRLF = b"\r\n"  # terminates every header line and argument
SYM_LF = b"\n"
SYM_EMPTY = b""  # separator used with .join() to glue fragments together

98 

99 

100class _Sentinel(enum.Enum): 

101 sentinel = object() 

102 

103 

104SENTINEL = _Sentinel.sentinel 

105 

106 

# Parser class used when the caller does not choose one: the hiredis-backed
# parser when HIREDIS_AVAILABLE is true, otherwise the pure RESP2 parser.
DefaultParser: Type[
    Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]
] = (_AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser)

112 

113 

class ConnectCallbackProtocol(Protocol):
    """Structural type for a plain (synchronous) connect callback."""

    def __call__(self, connection: "AbstractConnection"):
        """Called with the connection that has just been established."""


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for a coroutine-function connect callback."""

    async def __call__(self, connection: "AbstractConnection"):
        """Awaited with the connection that has just been established."""


# Either callback flavour is acceptable wherever a connect callback is taken.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

123 

124 

class AbstractConnection:
    """Manages communication to and from a Redis server"""

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        retry_on_error : list, optional
            Extra exception types to retry on. The sequence is copied, so the
            caller's object is never modified (timeout errors are appended to
            the copy when ``retry_on_timeout`` is set).

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        # Work on a private copy: the timeout errors are appended below, and
        # appending to the caller's list would mutate (and, if the list is
        # reused for further connections, keep growing) an object we do not own.
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        try:
            p = int(protocol)
        except TypeError:
            # protocol=None (or another non-numeric type) -> library default
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p

    def __del__(self, _warnings: Any = warnings):
        """Warn about, and best-effort close, a connection that was never closed."""
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                # Only close synchronously if an event loop is still running.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs rendered by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket, run the handshake, then fire connect callbacks.

        No-op when already connected.

        :param check_health: forwarded to ``on_connect_check_health``.
        :param retry_socket_connect: when True, the socket connect alone is
            wrapped in ``self.retry``; otherwise it is attempted once.
        :raises TimeoutError: if connecting timed out.
        :raises ConnectionError: for socket errors or any other failure while
            opening the socket.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            timeout_error = TimeoutError("Timeout connecting to server")
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=timeout_error,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise timeout_error
        except OSError as e:
            # Keep the original OSError reachable as __cause__ instead of
            # shadowing the caught name.
            connection_error = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=connection_error,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise connection_error from e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback is None:
                # The referent can be collected while an earlier callback is
                # being awaited; skip it instead of calling None.
                continue
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection so ``should_reconnect()`` returns True."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect()`` was called since the last reset."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done on every ``disconnect``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``self._reader`` / ``self._writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a short description of the peer for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake.

        Steps: AUTH (or HELLO ... AUTH for RESP3), HELLO for RESP3 without
        credentials, CLIENT SETNAME, pipelined CLIENT SETINFO and SELECT.

        :raises AuthenticationError: if AUTH does not answer OK.
        :raises ConnectionError: on a RESP version mismatch, a failed SETNAME,
            or a failed SELECT.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # NOTE: the HELLO reply is consumed but, unlike the authenticated
            # branch above, its "proto" field is intentionally not validated.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline; SETINFO errors are deliberately ignored
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server

        :param nowait: if True, do not await ``wait_closed()`` on the writer.
        :param error: the error that triggered the disconnect, if any; used
            for metrics.
        :param failure_count: retry failure count; an error metric is recorded
            once it exceeds the configured number of retries.
        :param health_check_failed: record the close reason as a failed
            health check instead of a generic error.
        :raises TimeoutError: if closing takes longer than
            ``socket_connect_timeout``.
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on socket errors.
        Any failure disconnects the connection before propagating.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            # OSError may carry (errno, strerror), a single message, or no
            # args at all; never let the unpacking itself raise IndexError.
            if len(e.args) >= 2:
                err_no, errmsg = e.args[0], e.args[1]
            elif len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no, errmsg = "UNKNOWN", type(e).__name__
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param disable_decoding: forwarded to the parser.
        :param timeout: per-call timeout; when given and it expires, ``None``
            is returned instead of raising. Defaults to ``socket_timeout``.
        :param disconnect_on_error: disconnect before re-raising read errors.
        :param push_request: forwarded to the parser only for RESP3.
        :raises TimeoutError: if ``socket_timeout`` (not ``timeout``) expires.
        :raises ConnectionError: on socket errors.
        :raises ResponseError: if the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        # push_request is only passed to the parser when speaking RESP3.
        parser_kwargs = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        try:
            if read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(**parser_kwargs)
            else:
                response = await self._parser.read_response(**parser_kwargs)
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): relies on the private StreamReader._buffer attribute.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read pending push messages until the reader's buffer is drained."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent on the next ``re_auth()`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

871 

872 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store TCP endpoint/socket options, then defer to AbstractConnection.

        :param host: server host name or address.
        :param port: server port; converted with ``int()``.
        :param socket_keepalive: enable SO_KEEPALIVE on the socket.
        :param socket_keepalive_options: extra ``SOL_TCP`` options applied
            only when ``socket_keepalive`` is true.
        :param socket_type: stored on the instance.
        :param kwargs: forwarded to ``AbstractConnection.__init__``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection

        The streams are stored on the instance only after the socket has been
        configured successfully. If any socket option fails, the writer is
        closed and the error re-raised, so ``is_connected`` never reports a
        closed transport as connected.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

928 

929 

930class SSLConnection(Connection): 

931 """Manages SSL connections to and from the Redis server(s). 

932 This class extends the Connection class, adding SSL functionality, and making 

933 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) 

934 """ 

935 

936 def __init__( 

937 self, 

938 ssl_keyfile: Optional[str] = None, 

939 ssl_certfile: Optional[str] = None, 

940 ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", 

941 ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

942 ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

943 ssl_ca_certs: Optional[str] = None, 

944 ssl_ca_data: Optional[str] = None, 

945 ssl_ca_path: Optional[str] = None, 

946 ssl_check_hostname: bool = True, 

947 ssl_min_version: Optional[TLSVersion] = None, 

948 ssl_ciphers: Optional[str] = None, 

949 ssl_password: Optional[str] = None, 

950 **kwargs, 

951 ): 

952 if not SSL_AVAILABLE: 

953 raise RedisError("Python wasn't built with SSL support") 

954 

955 self.ssl_context: RedisSSLContext = RedisSSLContext( 

956 keyfile=ssl_keyfile, 

957 certfile=ssl_certfile, 

958 cert_reqs=ssl_cert_reqs, 

959 include_verify_flags=ssl_include_verify_flags, 

960 exclude_verify_flags=ssl_exclude_verify_flags, 

961 ca_certs=ssl_ca_certs, 

962 ca_data=ssl_ca_data, 

963 ca_path=ssl_ca_path, 

964 check_hostname=ssl_check_hostname, 

965 min_version=ssl_min_version, 

966 ciphers=ssl_ciphers, 

967 password=ssl_password, 

968 ) 

969 super().__init__(**kwargs) 

970 

971 def _connection_arguments(self) -> Mapping: 

972 kwargs = super()._connection_arguments() 

973 kwargs["ssl"] = self.ssl_context.get() 

974 return kwargs 

975 

976 @property 

977 def keyfile(self): 

978 return self.ssl_context.keyfile 

979 

980 @property 

981 def certfile(self): 

982 return self.ssl_context.certfile 

983 

984 @property 

985 def cert_reqs(self): 

986 return self.ssl_context.cert_reqs 

987 

988 @property 

989 def include_verify_flags(self): 

990 return self.ssl_context.include_verify_flags 

991 

992 @property 

993 def exclude_verify_flags(self): 

994 return self.ssl_context.exclude_verify_flags 

995 

996 @property 

997 def ca_certs(self): 

998 return self.ssl_context.ca_certs 

999 

1000 @property 

1001 def ca_data(self): 

1002 return self.ssl_context.ca_data 

1003 

1004 @property 

1005 def check_hostname(self): 

1006 return self.ssl_context.check_hostname 

1007 

1008 @property 

1009 def min_version(self): 

1010 return self.ssl_context.min_version 

1011 

1012 

1013class RedisSSLContext: 

1014 __slots__ = ( 

1015 "keyfile", 

1016 "certfile", 

1017 "cert_reqs", 

1018 "include_verify_flags", 

1019 "exclude_verify_flags", 

1020 "ca_certs", 

1021 "ca_data", 

1022 "ca_path", 

1023 "context", 

1024 "check_hostname", 

1025 "min_version", 

1026 "ciphers", 

1027 "password", 

1028 ) 

1029 

1030 def __init__( 

1031 self, 

1032 keyfile: Optional[str] = None, 

1033 certfile: Optional[str] = None, 

1034 cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, 

1035 include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

1036 exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

1037 ca_certs: Optional[str] = None, 

1038 ca_data: Optional[str] = None, 

1039 ca_path: Optional[str] = None, 

1040 check_hostname: bool = False, 

1041 min_version: Optional[TLSVersion] = None, 

1042 ciphers: Optional[str] = None, 

1043 password: Optional[str] = None, 

1044 ): 

1045 if not SSL_AVAILABLE: 

1046 raise RedisError("Python wasn't built with SSL support") 

1047 

1048 self.keyfile = keyfile 

1049 self.certfile = certfile 

1050 if cert_reqs is None: 

1051 cert_reqs = ssl.CERT_NONE 

1052 elif isinstance(cert_reqs, str): 

1053 CERT_REQS = { # noqa: N806 

1054 "none": ssl.CERT_NONE, 

1055 "optional": ssl.CERT_OPTIONAL, 

1056 "required": ssl.CERT_REQUIRED, 

1057 } 

1058 if cert_reqs not in CERT_REQS: 

1059 raise RedisError( 

1060 f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" 

1061 ) 

1062 cert_reqs = CERT_REQS[cert_reqs] 

1063 self.cert_reqs = cert_reqs 

1064 self.include_verify_flags = include_verify_flags 

1065 self.exclude_verify_flags = exclude_verify_flags 

1066 self.ca_certs = ca_certs 

1067 self.ca_data = ca_data 

1068 self.ca_path = ca_path 

1069 self.check_hostname = ( 

1070 check_hostname if self.cert_reqs != ssl.CERT_NONE else False 

1071 ) 

1072 self.min_version = min_version 

1073 self.ciphers = ciphers 

1074 self.password = password 

1075 self.context: Optional[SSLContext] = None 

1076 

1077 def get(self) -> SSLContext: 

1078 if not self.context: 

1079 context = ssl.create_default_context() 

1080 context.check_hostname = self.check_hostname 

1081 context.verify_mode = self.cert_reqs 

1082 if self.include_verify_flags: 

1083 for flag in self.include_verify_flags: 

1084 context.verify_flags |= flag 

1085 if self.exclude_verify_flags: 

1086 for flag in self.exclude_verify_flags: 

1087 context.verify_flags &= ~flag 

1088 if self.certfile or self.keyfile: 

1089 context.load_cert_chain( 

1090 certfile=self.certfile, 

1091 keyfile=self.keyfile, 

1092 password=self.password, 

1093 ) 

1094 if self.ca_certs or self.ca_data or self.ca_path: 

1095 context.load_verify_locations( 

1096 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data 

1097 ) 

1098 if self.min_version is not None: 

1099 context.minimum_version = self.min_version 

1100 if self.ciphers is not None: 

1101 context.set_ciphers(self.ciphers) 

1102 self.context = context 

1103 return self.context 

1104 

1105 

1106class UnixDomainSocketConnection(AbstractConnection): 

1107 "Manages UDS communication to and from a Redis server" 

1108 

1109 def __init__(self, *, path: str = "", **kwargs): 

1110 self.path = path 

1111 super().__init__(**kwargs) 

1112 

1113 def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: 

1114 pieces = [("path", self.path), ("db", self.db)] 

1115 if self.client_name: 

1116 pieces.append(("client_name", self.client_name)) 

1117 return pieces 

1118 

1119 async def _connect(self): 

1120 async with async_timeout(self.socket_connect_timeout): 

1121 reader, writer = await asyncio.open_unix_connection(path=self.path) 

1122 self._reader = reader 

1123 self._writer = writer 

1124 await self.on_connect() 

1125 

1126 def _host_error(self) -> str: 

1127 return self.path 

1128 

1129 

1130FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") 

1131 

1132 

1133def to_bool(value) -> Optional[bool]: 

1134 if value is None or value == "": 

1135 return None 

1136 if isinstance(value, str) and value.upper() in FALSE_STRINGS: 

1137 return False 

1138 return bool(value) 

1139 

1140 

1141def parse_ssl_verify_flags(value): 

1142 # flags are passed in as a string representation of a list, 

1143 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1144 verify_flags_str = value.replace("[", "").replace("]", "") 

1145 

1146 verify_flags = [] 

1147 for flag in verify_flags_str.split(","): 

1148 flag = flag.strip() 

1149 if not hasattr(VerifyFlags, flag): 

1150 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1151 verify_flags.append(getattr(VerifyFlags, flag)) 

1152 return verify_flags 

1153 

1154 

1155URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( 

1156 { 

1157 "db": int, 

1158 "socket_timeout": float, 

1159 "socket_connect_timeout": float, 

1160 "socket_keepalive": to_bool, 

1161 "retry_on_timeout": to_bool, 

1162 "max_connections": int, 

1163 "health_check_interval": int, 

1164 "ssl_check_hostname": to_bool, 

1165 "ssl_include_verify_flags": parse_ssl_verify_flags, 

1166 "ssl_exclude_verify_flags": parse_ssl_verify_flags, 

1167 "timeout": float, 

1168 } 

1169) 

1170 

1171 

1172class ConnectKwargs(TypedDict, total=False): 

1173 username: str 

1174 password: str 

1175 connection_class: Type[AbstractConnection] 

1176 host: str 

1177 port: int 

1178 db: int 

1179 path: str 

1180 

1181 

1182def parse_url(url: str) -> ConnectKwargs: 

1183 parsed: ParseResult = urlparse(url) 

1184 kwargs: ConnectKwargs = {} 

1185 

1186 for name, value_list in parse_qs(parsed.query).items(): 

1187 if value_list and len(value_list) > 0: 

1188 value = unquote(value_list[0]) 

1189 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

1190 if parser: 

1191 try: 

1192 kwargs[name] = parser(value) 

1193 except (TypeError, ValueError): 

1194 raise ValueError(f"Invalid value for '{name}' in connection URL.") 

1195 else: 

1196 kwargs[name] = value 

1197 

1198 if parsed.username: 

1199 kwargs["username"] = unquote(parsed.username) 

1200 if parsed.password: 

1201 kwargs["password"] = unquote(parsed.password) 

1202 

1203 # We only support redis://, rediss:// and unix:// schemes. 

1204 if parsed.scheme == "unix": 

1205 if parsed.path: 

1206 kwargs["path"] = unquote(parsed.path) 

1207 kwargs["connection_class"] = UnixDomainSocketConnection 

1208 

1209 elif parsed.scheme in ("redis", "rediss"): 

1210 if parsed.hostname: 

1211 kwargs["host"] = unquote(parsed.hostname) 

1212 if parsed.port: 

1213 kwargs["port"] = int(parsed.port) 

1214 

1215 # If there's a path argument, use it as the db argument if a 

1216 # querystring value wasn't specified 

1217 if parsed.path and "db" not in kwargs: 

1218 try: 

1219 kwargs["db"] = int(unquote(parsed.path).replace("/", "")) 

1220 except (AttributeError, ValueError): 

1221 pass 

1222 

1223 if parsed.scheme == "rediss": 

1224 kwargs["connection_class"] = SSLConnection 

1225 

1226 else: 

1227 valid_schemes = "redis://, rediss://, unix://" 

1228 raise ValueError( 

1229 f"Redis URL must specify one of the following schemes ({valid_schemes})" 

1230 ) 

1231 

1232 return kwargs 

1233 

1234 

1235_CP = TypeVar("_CP", bound="ConnectionPool") 

1236 

1237 

1238class ConnectionPool: 

1239 """ 

1240 Create a connection pool. ``If max_connections`` is set, then this 

1241 object raises :py:class:`~redis.ConnectionError` when the pool's 

1242 limit is reached. 

1243 

1244 By default, TCP connections are created unless ``connection_class`` 

1245 is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for 

1246 unix sockets. 

1247 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

1248 

1249 Any additional keyword arguments are passed to the constructor of 

1250 ``connection_class``. 

1251 """ 

1252 

1253 @classmethod 

1254 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

1255 """ 

1256 Return a connection pool configured from the given URL. 

1257 

1258 For example:: 

1259 

1260 redis://[[username]:[password]]@localhost:6379/0 

1261 rediss://[[username]:[password]]@localhost:6379/0 

1262 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

1263 

1264 Three URL schemes are supported: 

1265 

1266 - `redis://` creates a TCP socket connection. See more at: 

1267 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

1268 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

1269 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

1270 - ``unix://``: creates a Unix Domain Socket connection. 

1271 

1272 The username, password, hostname, path and all querystring values 

1273 are passed through urllib.parse.unquote in order to replace any 

1274 percent-encoded values with their corresponding characters. 

1275 

1276 There are several ways to specify a database number. The first value 

1277 found will be used: 

1278 

1279 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

1280 

1281 2. If using the redis:// or rediss:// schemes, the path argument 

1282 of the url, e.g. redis://localhost/0 

1283 

1284 3. A ``db`` keyword argument to this function. 

1285 

1286 If none of these options are specified, the default db=0 is used. 

1287 

1288 All querystring options are cast to their appropriate Python types. 

1289 Boolean arguments can be specified with string values "True"/"False" 

1290 or "Yes"/"No". Values that cannot be properly cast cause a 

1291 ``ValueError`` to be raised. Once parsed, the querystring arguments 

1292 and keyword arguments are passed to the ``ConnectionPool``'s 

1293 class initializer. In the case of conflicting arguments, querystring 

1294 arguments always win. 

1295 """ 

1296 url_options = parse_url(url) 

1297 kwargs.update(url_options) 

1298 return cls(**kwargs) 

1299 

1300 def __init__( 

1301 self, 

1302 connection_class: Type[AbstractConnection] = Connection, 

1303 max_connections: Optional[int] = None, 

1304 **connection_kwargs, 

1305 ): 

1306 max_connections = max_connections or 2**31 

1307 if not isinstance(max_connections, int) or max_connections < 0: 

1308 raise ValueError('"max_connections" must be a positive integer') 

1309 

1310 self.connection_class = connection_class 

1311 self.connection_kwargs = connection_kwargs 

1312 self.max_connections = max_connections 

1313 

1314 self._available_connections: List[AbstractConnection] = [] 

1315 self._in_use_connections: Set[AbstractConnection] = set() 

1316 self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) 

1317 self._lock = asyncio.Lock() 

1318 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) 

1319 if self._event_dispatcher is None: 

1320 self._event_dispatcher = EventDispatcher() 

1321 

1322 # Keys that should be redacted in __repr__ to avoid exposing sensitive information 

1323 SENSITIVE_REPR_KEYS = frozenset( 

1324 { 

1325 "password", 

1326 "username", 

1327 "ssl_password", 

1328 "credential_provider", 

1329 } 

1330 ) 

1331 

1332 def __repr__(self): 

1333 conn_kwargs = ",".join( 

1334 [ 

1335 f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}" 

1336 for k, v in self.connection_kwargs.items() 

1337 ] 

1338 ) 

1339 return ( 

1340 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

1341 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

1342 f"({conn_kwargs})>)>" 

1343 ) 

1344 

1345 def reset(self): 

1346 # Record metrics for connections being removed before clearing 

1347 # (only if attributes exist - they won't during __init__) 

1348 if hasattr(self, "_available_connections") and hasattr( 

1349 self, "_in_use_connections" 

1350 ): 

1351 idle_count = len(self._available_connections) 

1352 in_use_count = len(self._in_use_connections) 

1353 if idle_count > 0 or in_use_count > 0: 

1354 pool_name = get_pool_name(self) 

1355 # Note: Using sync version since reset() is sync 

1356 from redis.observability.recorder import ( 

1357 record_connection_count as sync_record_connection_count, 

1358 ) 

1359 

1360 if idle_count > 0: 

1361 sync_record_connection_count( 

1362 pool_name=pool_name, 

1363 connection_state=ConnectionState.IDLE, 

1364 counter=-idle_count, 

1365 ) 

1366 if in_use_count > 0: 

1367 sync_record_connection_count( 

1368 pool_name=pool_name, 

1369 connection_state=ConnectionState.USED, 

1370 counter=-in_use_count, 

1371 ) 

1372 

1373 self._available_connections = [] 

1374 self._in_use_connections = weakref.WeakSet() 

1375 

1376 def __del__(self) -> None: 

1377 """Clean up connection pool and record metrics when garbage collected.""" 

1378 try: 

1379 if not hasattr(self, "_available_connections") or not hasattr( 

1380 self, "_in_use_connections" 

1381 ): 

1382 return 

1383 idle_count = len(self._available_connections) 

1384 in_use_count = len(self._in_use_connections) 

1385 if idle_count > 0 or in_use_count > 0: 

1386 pool_name = get_pool_name(self) 

1387 # Note: Using sync version since __del__ is sync 

1388 from redis.observability.recorder import ( 

1389 record_connection_count as sync_record_connection_count, 

1390 ) 

1391 

1392 if idle_count > 0: 

1393 sync_record_connection_count( 

1394 pool_name=pool_name, 

1395 connection_state=ConnectionState.IDLE, 

1396 counter=-idle_count, 

1397 ) 

1398 if in_use_count > 0: 

1399 sync_record_connection_count( 

1400 pool_name=pool_name, 

1401 connection_state=ConnectionState.USED, 

1402 counter=-in_use_count, 

1403 ) 

1404 except Exception: 

1405 pass 

1406 

1407 def can_get_connection(self) -> bool: 

1408 """Return True if a connection can be retrieved from the pool.""" 

1409 return ( 

1410 self._available_connections 

1411 or len(self._in_use_connections) < self.max_connections 

1412 ) 

1413 

1414 @deprecated_args( 

1415 args_to_warn=["*"], 

1416 reason="Use get_connection() without args instead", 

1417 version="5.3.0", 

1418 ) 

1419 async def get_connection(self, command_name=None, *keys, **options): 

1420 """Get a connected connection from the pool""" 

1421 # Track connection count before to detect if a new connection is created 

1422 async with self._lock: 

1423 connections_before = len(self._available_connections) + len( 

1424 self._in_use_connections 

1425 ) 

1426 start_time_created = time.monotonic() 

1427 connection = self.get_available_connection() 

1428 connections_after = len(self._available_connections) + len( 

1429 self._in_use_connections 

1430 ) 

1431 is_created = connections_after > connections_before 

1432 

1433 # Record state transition for observability 

1434 # This ensures counters stay balanced if ensure_connection() fails and release() is called 

1435 pool_name = get_pool_name(self) 

1436 if is_created: 

1437 # New connection created and acquired: just USED +1 

1438 await record_connection_count( 

1439 pool_name=pool_name, 

1440 connection_state=ConnectionState.USED, 

1441 counter=1, 

1442 ) 

1443 else: 

1444 # Existing connection acquired from pool: IDLE -> USED 

1445 await record_connection_count( 

1446 pool_name=pool_name, 

1447 connection_state=ConnectionState.IDLE, 

1448 counter=-1, 

1449 ) 

1450 await record_connection_count( 

1451 pool_name=pool_name, 

1452 connection_state=ConnectionState.USED, 

1453 counter=1, 

1454 ) 

1455 

1456 # We now perform the connection check outside of the lock. 

1457 try: 

1458 await self.ensure_connection(connection) 

1459 

1460 if is_created: 

1461 await record_connection_create_time( 

1462 connection_pool=self, 

1463 duration_seconds=time.monotonic() - start_time_created, 

1464 ) 

1465 

1466 return connection 

1467 except BaseException: 

1468 await self.release(connection) 

1469 raise 

1470 

1471 def get_available_connection(self): 

1472 """Get a connection from the pool, without making sure it is connected""" 

1473 try: 

1474 connection = self._available_connections.pop() 

1475 except IndexError: 

1476 if len(self._in_use_connections) >= self.max_connections: 

1477 raise MaxConnectionsError("Too many connections") from None 

1478 connection = self.make_connection() 

1479 self._in_use_connections.add(connection) 

1480 return connection 

1481 

1482 def get_encoder(self): 

1483 """Return an encoder based on encoding settings""" 

1484 kwargs = self.connection_kwargs 

1485 return self.encoder_class( 

1486 encoding=kwargs.get("encoding", "utf-8"), 

1487 encoding_errors=kwargs.get("encoding_errors", "strict"), 

1488 decode_responses=kwargs.get("decode_responses", False), 

1489 ) 

1490 

1491 def make_connection(self): 

1492 """Create a new connection. Can be overridden by child classes.""" 

1493 # Note: We don't record IDLE here because async uses a sync make_connection 

1494 # but async record_connection_count. The recording is handled in get_connection. 

1495 return self.connection_class(**self.connection_kwargs) 

1496 

1497 async def ensure_connection(self, connection: AbstractConnection): 

1498 """Ensure that the connection object is connected and valid""" 

1499 await connection.connect() 

1500 # connections that the pool provides should be ready to send 

1501 # a command. if not, the connection was either returned to the 

1502 # pool before all data has been read or the socket has been 

1503 # closed. either way, reconnect and verify everything is good. 

1504 try: 

1505 if await connection.can_read_destructive(): 

1506 raise ConnectionError("Connection has data") from None 

1507 except (ConnectionError, TimeoutError, OSError): 

1508 await connection.disconnect() 

1509 await connection.connect() 

1510 if await connection.can_read_destructive(): 

1511 raise ConnectionError("Connection not ready") from None 

1512 

1513 async def release(self, connection: AbstractConnection): 

1514 """Releases the connection back to the pool""" 

1515 # Connections should always be returned to the correct pool, 

1516 # not doing so is an error that will cause an exception here. 

1517 self._in_use_connections.remove(connection) 

1518 

1519 if connection.should_reconnect(): 

1520 await connection.disconnect() 

1521 

1522 self._available_connections.append(connection) 

1523 await self._event_dispatcher.dispatch_async( 

1524 AsyncAfterConnectionReleasedEvent(connection) 

1525 ) 

1526 

1527 # Record state transition: USED -> IDLE 

1528 pool_name = get_pool_name(self) 

1529 await record_connection_count( 

1530 pool_name=pool_name, 

1531 connection_state=ConnectionState.USED, 

1532 counter=-1, 

1533 ) 

1534 await record_connection_count( 

1535 pool_name=pool_name, 

1536 connection_state=ConnectionState.IDLE, 

1537 counter=1, 

1538 ) 

1539 

1540 async def disconnect(self, inuse_connections: bool = True): 

1541 """ 

1542 Disconnects connections in the pool 

1543 

1544 If ``inuse_connections`` is True, disconnect connections that are 

1545 current in use, potentially by other tasks. Otherwise only disconnect 

1546 connections that are idle in the pool. 

1547 """ 

1548 if inuse_connections: 

1549 connections: Iterable[AbstractConnection] = chain( 

1550 self._available_connections, self._in_use_connections 

1551 ) 

1552 else: 

1553 connections = self._available_connections 

1554 resp = await asyncio.gather( 

1555 *(connection.disconnect() for connection in connections), 

1556 return_exceptions=True, 

1557 ) 

1558 

1559 exc = next((r for r in resp if isinstance(r, BaseException)), None) 

1560 if exc: 

1561 raise exc 

1562 

1563 async def update_active_connections_for_reconnect(self): 

1564 """ 

1565 Mark all active connections for reconnect. 

1566 """ 

1567 async with self._lock: 

1568 for conn in self._in_use_connections: 

1569 conn.mark_for_reconnect() 

1570 

1571 async def aclose(self) -> None: 

1572 """Close the pool, disconnecting all connections""" 

1573 await self.disconnect() 

1574 

1575 def set_retry(self, retry: "Retry") -> None: 

1576 for conn in self._available_connections: 

1577 conn.retry = retry 

1578 for conn in self._in_use_connections: 

1579 conn.retry = retry 

1580 

1581 async def re_auth_callback(self, token: TokenInterface): 

1582 async with self._lock: 

1583 for conn in self._available_connections: 

1584 await conn.retry.call_with_retry( 

1585 lambda: conn.send_command( 

1586 "AUTH", token.try_get("oid"), token.get_value() 

1587 ), 

1588 lambda error: self._mock(error), 

1589 ) 

1590 await conn.retry.call_with_retry( 

1591 lambda: conn.read_response(), lambda error: self._mock(error) 

1592 ) 

1593 for conn in self._in_use_connections: 

1594 conn.set_re_auth_token(token) 

1595 

1596 async def _mock(self, error: RedisError): 

1597 """ 

1598 Dummy functions, needs to be passed as error callback to retry object. 

1599 :param error: 

1600 :return: 

1601 """ 

1602 pass 

1603 

1604 def get_connection_count(self) -> List[tuple[int, dict]]: 

1605 """ 

1606 Returns a connection count (both idle and in use). 

1607 """ 

1608 attributes = AttributeBuilder.build_base_attributes() 

1609 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self) 

1610 free_connections_attributes = attributes.copy() 

1611 in_use_connections_attributes = attributes.copy() 

1612 

1613 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

1614 ConnectionState.IDLE.value 

1615 ) 

1616 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

1617 ConnectionState.USED.value 

1618 ) 

1619 

1620 return [ 

1621 (len(self._available_connections), free_connections_attributes), 

1622 (len(self._in_use_connections), in_use_connections_attributes), 

1623 ] 

1624 

1625 

1626class BlockingConnectionPool(ConnectionPool): 

1627 """ 

1628 A blocking connection pool:: 

1629 

1630 >>> from redis.asyncio import Redis, BlockingConnectionPool 

1631 >>> client = Redis.from_pool(BlockingConnectionPool()) 

1632 

1633 It performs the same function as the default 

1634 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, 

1635 it maintains a pool of reusable connections that can be shared by 

1636 multiple async redis clients. 

1637 

1638 The difference is that, in the event that a client tries to get a 

1639 connection from the pool when all of connections are in use, rather than 

1640 raising a :py:class:`~redis.ConnectionError` (as the default 

1641 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it 

1642 blocks the current `Task` for a specified number of seconds until 

1643 a connection becomes available. 

1644 

1645 Use ``max_connections`` to increase / decrease the pool size:: 

1646 

1647 >>> pool = BlockingConnectionPool(max_connections=10) 

1648 

1649 Use ``timeout`` to tell it either how many seconds to wait for a connection 

1650 to become available, or to block forever: 

1651 

1652 >>> # Block forever. 

1653 >>> pool = BlockingConnectionPool(timeout=None) 

1654 

1655 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

1656 >>> # not available. 

1657 >>> pool = BlockingConnectionPool(timeout=5) 

1658 """ 

1659 

1660 def __init__( 

1661 self, 

1662 max_connections: int = 50, 

1663 timeout: Optional[float] = 20, 

1664 connection_class: Type[AbstractConnection] = Connection, 

1665 queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated 

1666 **connection_kwargs, 

1667 ): 

1668 super().__init__( 

1669 connection_class=connection_class, 

1670 max_connections=max_connections, 

1671 **connection_kwargs, 

1672 ) 

1673 self._condition = asyncio.Condition() 

1674 self.timeout = timeout 

1675 

1676 @deprecated_args( 

1677 args_to_warn=["*"], 

1678 reason="Use get_connection() without args instead", 

1679 version="5.3.0", 

1680 ) 

1681 async def get_connection(self, command_name=None, *keys, **options): 

1682 """Gets a connection from the pool, blocking until one is available""" 

1683 # Start timing for wait time observability 

1684 start_time_acquired = time.monotonic() 

1685 

1686 try: 

1687 async with self._condition: 

1688 async with async_timeout(self.timeout): 

1689 await self._condition.wait_for(self.can_get_connection) 

1690 # Track connection count before to detect if a new connection is created 

1691 connections_before = len(self._available_connections) + len( 

1692 self._in_use_connections 

1693 ) 

1694 start_time_created = time.monotonic() 

1695 connection = super().get_available_connection() 

1696 connections_after = len(self._available_connections) + len( 

1697 self._in_use_connections 

1698 ) 

1699 is_created = connections_after > connections_before 

1700 except asyncio.TimeoutError as err: 

1701 raise ConnectionError("No connection available.") from err 

1702 

1703 # We now perform the connection check outside of the lock. 

1704 try: 

1705 await self.ensure_connection(connection) 

1706 

1707 if is_created: 

1708 await record_connection_create_time( 

1709 connection_pool=self, 

1710 duration_seconds=time.monotonic() - start_time_created, 

1711 ) 

1712 

1713 await record_connection_wait_time( 

1714 pool_name=get_pool_name(self), 

1715 duration_seconds=time.monotonic() - start_time_acquired, 

1716 ) 

1717 

1718 return connection 

1719 except BaseException: 

1720 await self.release(connection) 

1721 raise 

1722 

1723 async def release(self, connection: AbstractConnection): 

1724 """Releases the connection back to the pool.""" 

1725 async with self._condition: 

1726 await super().release(connection) 

1727 self._condition.notify()