Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

749 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import time 

8import warnings 

9import weakref 

10from abc import abstractmethod 

11from itertools import chain 

12from types import MappingProxyType 

13from typing import ( 

14 Any, 

15 Callable, 

16 Iterable, 

17 List, 

18 Mapping, 

19 Optional, 

20 Protocol, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypedDict, 

25 TypeVar, 

26 Union, 

27) 

28from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

29 

30from ..observability.attributes import ( 

31 DB_CLIENT_CONNECTION_POOL_NAME, 

32 DB_CLIENT_CONNECTION_STATE, 

33 AttributeBuilder, 

34 ConnectionState, 

35 get_pool_name, 

36) 

37from ..utils import SSL_AVAILABLE 

38 

39if SSL_AVAILABLE: 

40 import ssl 

41 from ssl import SSLContext, TLSVersion, VerifyFlags 

42else: 

43 ssl = None 

44 TLSVersion = None 

45 SSLContext = None 

46 VerifyFlags = None 

47 

48from ..auth.token import TokenInterface 

49from ..driver_info import DriverInfo, resolve_driver_info 

50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

51from ..utils import deprecated_args, format_error_message 

52 

53# the functionality is available in 3.11.x but has a major issue before 

54# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

55if sys.version_info >= (3, 11, 3): 

56 from asyncio import timeout as async_timeout 

57else: 

58 from async_timeout import timeout as async_timeout 

59 

60from redis.asyncio.observability.recorder import ( 

61 record_connection_closed, 

62 record_connection_create_time, 

63 record_connection_wait_time, 

64 record_error_count, 

65) 

66from redis.asyncio.retry import Retry 

67from redis.backoff import NoBackoff 

68from redis.connection import DEFAULT_RESP_VERSION 

69from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

70from redis.exceptions import ( 

71 AuthenticationError, 

72 AuthenticationWrongNumberOfArgsError, 

73 ConnectionError, 

74 DataError, 

75 MaxConnectionsError, 

76 RedisError, 

77 ResponseError, 

78 TimeoutError, 

79) 

80from redis.observability.metrics import CloseReason 

81from redis.typing import EncodableT 

82from redis.utils import HIREDIS_AVAILABLE, str_if_bytes 

83 

84from .._parsers import ( 

85 BaseParser, 

86 Encoder, 

87 _AsyncHiredisParser, 

88 _AsyncRESP2Parser, 

89 _AsyncRESP3Parser, 

90) 

91 

# RESP wire-protocol framing symbols used when packing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""


class _Sentinel(enum.Enum):
    """Unique sentinel to distinguish "argument not supplied" from None."""

    sentinel = object()


SENTINEL = _Sentinel.sentinel


# Prefer the hiredis C-extension parser when installed; otherwise fall back
# to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser


class ConnectCallbackProtocol(Protocol):
    """Signature of a synchronous on-connect callback."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Signature of an asynchronous on-connect callback."""

    async def __call__(self, connection: "AbstractConnection"): ...


# A sync or async callable invoked after the connection is (re)established.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

122 

123 

class AbstractConnection:
    """Manages communication to and from a Redis server"""

    # __slots__ keeps per-connection memory low; "__dict__" is retained at the
    # end so subclasses and mixins may still attach ad-hoc attributes.
    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

155 

156 @deprecated_args( 

157 args_to_warn=["lib_name", "lib_version"], 

158 reason="Use 'driver_info' parameter instead. " 

159 "lib_name and lib_version will be removed in a future version.", 

160 ) 

161 def __init__( 

162 self, 

163 *, 

164 db: Union[str, int] = 0, 

165 password: Optional[str] = None, 

166 socket_timeout: Optional[float] = None, 

167 socket_connect_timeout: Optional[float] = None, 

168 retry_on_timeout: bool = False, 

169 retry_on_error: Union[list, _Sentinel] = SENTINEL, 

170 encoding: str = "utf-8", 

171 encoding_errors: str = "strict", 

172 decode_responses: bool = False, 

173 parser_class: Type[BaseParser] = DefaultParser, 

174 socket_read_size: int = 65536, 

175 health_check_interval: float = 0, 

176 client_name: Optional[str] = None, 

177 lib_name: Optional[str] = None, 

178 lib_version: Optional[str] = None, 

179 driver_info: Optional[DriverInfo] = None, 

180 username: Optional[str] = None, 

181 retry: Optional[Retry] = None, 

182 redis_connect_func: Optional[ConnectCallbackT] = None, 

183 encoder_class: Type[Encoder] = Encoder, 

184 credential_provider: Optional[CredentialProvider] = None, 

185 protocol: Optional[int] = 2, 

186 event_dispatcher: Optional[EventDispatcher] = None, 

187 ): 

188 """ 

189 Initialize a new async Connection. 

190 

191 Parameters 

192 ---------- 

193 driver_info : DriverInfo, optional 

194 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version 

195 are ignored. If not provided, a DriverInfo will be created from lib_name 

196 and lib_version (or defaults if those are also None). 

197 lib_name : str, optional 

198 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO. 

199 lib_version : str, optional 

200 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO. 

201 """ 

202 if (username or password) and credential_provider is not None: 

203 raise DataError( 

204 "'username' and 'password' cannot be passed along with 'credential_" 

205 "provider'. Please provide only one of the following arguments: \n" 

206 "1. 'password' and (optional) 'username'\n" 

207 "2. 'credential_provider'" 

208 ) 

209 if event_dispatcher is None: 

210 self._event_dispatcher = EventDispatcher() 

211 else: 

212 self._event_dispatcher = event_dispatcher 

213 self.db = db 

214 self.client_name = client_name 

215 

216 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version 

217 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version) 

218 

219 self.credential_provider = credential_provider 

220 self.password = password 

221 self.username = username 

222 self.socket_timeout = socket_timeout 

223 if socket_connect_timeout is None: 

224 socket_connect_timeout = socket_timeout 

225 self.socket_connect_timeout = socket_connect_timeout 

226 self.retry_on_timeout = retry_on_timeout 

227 if retry_on_error is SENTINEL: 

228 retry_on_error = [] 

229 if retry_on_timeout: 

230 retry_on_error.append(TimeoutError) 

231 retry_on_error.append(socket.timeout) 

232 retry_on_error.append(asyncio.TimeoutError) 

233 self.retry_on_error = retry_on_error 

234 if retry or retry_on_error: 

235 if not retry: 

236 self.retry = Retry(NoBackoff(), 1) 

237 else: 

238 # deep-copy the Retry object as it is mutable 

239 self.retry = copy.deepcopy(retry) 

240 # Update the retry's supported errors with the specified errors 

241 self.retry.update_supported_errors(retry_on_error) 

242 else: 

243 self.retry = Retry(NoBackoff(), 0) 

244 self.health_check_interval = health_check_interval 

245 self.next_health_check: float = -1 

246 self.encoder = encoder_class(encoding, encoding_errors, decode_responses) 

247 self.redis_connect_func = redis_connect_func 

248 self._reader: Optional[asyncio.StreamReader] = None 

249 self._writer: Optional[asyncio.StreamWriter] = None 

250 self._socket_read_size = socket_read_size 

251 self.set_parser(parser_class) 

252 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] 

253 self._buffer_cutoff = 6000 

254 self._re_auth_token: Optional[TokenInterface] = None 

255 self._should_reconnect = False 

256 

257 try: 

258 p = int(protocol) 

259 except TypeError: 

260 p = DEFAULT_RESP_VERSION 

261 except ValueError: 

262 raise ConnectionError("protocol must be an integer") 

263 else: 

264 if p < 2 or p > 3: 

265 raise ConnectionError("protocol must be either 2 or 3") 

266 self.protocol = p 

267 

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        # NOTE: ``_warnings`` is bound as a default argument so the module
        # stays reachable even during interpreter shutdown.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            # Best-effort cleanup: closing is only possible while an event
            # loop is still running on this thread.
            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass

283 

284 def _close(self): 

285 """ 

286 Internal method to silently close the connection without waiting 

287 """ 

288 if self._writer: 

289 self._writer.close() 

290 self._writer = self._reader = None 

291 

292 def __repr__(self): 

293 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) 

294 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

295 

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs to include in __repr__ (per subclass)."""
        pass

299 

300 @property 

301 def is_connected(self): 

302 return self._reader is not None and self._writer is not None 

303 

304 def register_connect_callback(self, callback): 

305 """ 

306 Register a callback to be called when the connection is established either 

307 initially or reconnected. This allows listeners to issue commands that 

308 are ephemeral to the connection, for example pub/sub subscription or 

309 key tracking. The callback must be a _method_ and will be kept as 

310 a weak reference. 

311 """ 

312 wm = weakref.WeakMethod(callback) 

313 if wm not in self._connect_callbacks: 

314 self._connect_callbacks.append(wm) 

315 

316 def deregister_connect_callback(self, callback): 

317 """ 

318 De-register a previously registered callback. It will no-longer receive 

319 notifications on connection events. Calling this is not required when the 

320 listener goes away, since the callbacks are kept as weak methods. 

321 """ 

322 try: 

323 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

324 except ValueError: 

325 pass 

326 

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Replace this connection's response parser.

        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection.

        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

334 

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            # On each failed attempt, disconnect before retrying.
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

348 

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Establish the socket connection and run the post-connect handshake.

        :param check_health: forwarded to the default on-connect handshake.
        :param retry_socket_connect: when True, the raw socket connect is
            retried according to ``self.retry``. ``connect()`` passes False
            because it retries the whole connect/handshake flow itself.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            e = TimeoutError("Timeout connecting to server")
            # Record the failure for observability before surfacing it.
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except OSError as e:
            # Wrap low-level socket errors in the library's ConnectionError.
            e = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

425 

    def mark_for_reconnect(self):
        # Request that the connection be re-established before its next use.
        self._should_reconnect = True

    def should_reconnect(self):
        # True when a reconnect was requested via mark_for_reconnect().
        return self._should_reconnect

    def reset_should_reconnect(self):
        # Clear a pending reconnect request.
        self._should_reconnect = False

434 

    @abstractmethod
    async def _connect(self):
        """Create the underlying transport (implemented by subclasses)."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint string for error messages."""
        pass

442 

    def _error_message(self, exception: BaseException) -> str:
        # Delegate to the shared formatter so all connection subclasses
        # produce a consistent error-message shape.
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version (2 or 3)."""
        return self.protocol

448 

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

452 

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """
        Run the post-connect handshake: AUTH / HELLO, CLIENT SETNAME,
        CLIENT SETINFO and SELECT, pipelining the latter commands to reduce
        startup round-trips.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # A single credential is a bare password; HELLO AUTH requires
                # a username, so use the server's "default" user.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # Response keys may be bytes or str depending on decode_responses.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # NOTE(review): proto validation is deliberately disabled on this
            # path (unlike the authenticated HELLO above) -- reason unclear
            # from this file; confirm before re-enabling.
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO is best-effort: older servers reject it.
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server

        :param nowait: when True, do not await the transport close handshake.
        :param error: exception (if any) that triggered the disconnect; used
            for observability reporting.
        :param failure_count: number of failed attempts so far; forwarded to
            error metrics once retries are exhausted.
        :param health_check_failed: mark the close as caused by a failed
            health check.
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            # Bound the entire teardown with the connect timeout.
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            # Only record the error metric once all retries are exhausted.
            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

623 

624 async def _send_ping(self): 

625 """Send PING, expect PONG in return""" 

626 await self.send_command("PING", check_health=False) 

627 if str_if_bytes(await self.read_response()) != "PONG": 

628 raise ConnectionError("Bad response from PING health check") 

629 

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        # Drop the connection, flagging the close as a failed health check
        # so observability records the right close reason.
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        # Only ping when health checks are enabled and the interval has
        # elapsed (next_health_check is advanced by read_response()).
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

645 

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Low-level write: queue every chunk, then apply flow control via
        # drain().
        self._writer.writelines(command)
        await self._writer.drain()

649 

650 async def send_packed_command( 

651 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True 

652 ) -> None: 

653 if not self.is_connected: 

654 await self.connect_check_health(check_health=False) 

655 if check_health: 

656 await self.check_health() 

657 

658 try: 

659 if isinstance(command, str): 

660 command = command.encode() 

661 if isinstance(command, bytes): 

662 command = [command] 

663 if self.socket_timeout: 

664 await asyncio.wait_for( 

665 self._send_packed_command(command), self.socket_timeout 

666 ) 

667 else: 

668 self._writer.writelines(command) 

669 await self._writer.drain() 

670 except asyncio.TimeoutError: 

671 await self.disconnect(nowait=True) 

672 raise TimeoutError("Timeout writing to socket") from None 

673 except OSError as e: 

674 await self.disconnect(nowait=True) 

675 if len(e.args) == 1: 

676 err_no, errmsg = "UNKNOWN", e.args[0] 

677 else: 

678 err_no = e.args[0] 

679 errmsg = e.args[1] 

680 raise ConnectionError( 

681 f"Error {err_no} while writing to socket. {errmsg}." 

682 ) from e 

683 except BaseException: 

684 # BaseExceptions can be raised when a socket send operation is not 

685 # finished, e.g. due to a timeout. Ideally, a caller could then re-try 

686 # to send un-sent data. However, the send_packed_command() API 

687 # does not support it so there is no point in keeping the connection open. 

688 await self.disconnect(nowait=True) 

689 raise 

690 

691 async def send_command(self, *args: Any, **kwargs: Any) -> None: 

692 """Pack and send a command to the Redis server""" 

693 await self.send_packed_command( 

694 self.pack_command(*args), check_health=kwargs.get("check_health", True) 

695 ) 

696 

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read.

        NOTE(review): delegates to the parser's "destructive" check, which
        presumably may consume buffered data -- confirm against the parser
        implementation before relying on this when a reply must be kept.
        """
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

705 

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param disable_decoding: skip response decoding in the parser.
        :param timeout: per-call read timeout overriding ``socket_timeout``;
            on expiry ``None`` is returned (instead of raising) so the caller
            may retry.
        :param disconnect_on_error: drop the connection on read errors
            (default) to avoid leaving unread data on the socket.
        :param push_request: forwarded to the parser only for RESP3.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # Four branches: with/without timeout x RESP3/RESP2 -- the
            # push_request kwarg is only passed on the RESP3 paths.
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            # A successful read proves liveness; push out the next check.
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            # Parsers may return (rather than raise) server errors; surface
            # them here without chaining context.
            raise response from None
        return response

763 

764 def pack_command(self, *args: EncodableT) -> List[bytes]: 

765 """Pack a series of arguments into the Redis protocol""" 

766 output = [] 

767 # the client might have included 1 or more literal arguments in 

768 # the command name, e.g., 'CONFIG GET'. The Redis server expects these 

769 # arguments to be sent separately, so split the first argument 

770 # manually. These arguments should be bytestrings so that they are 

771 # not encoded. 

772 assert not isinstance(args[0], float) 

773 if isinstance(args[0], str): 

774 args = tuple(args[0].encode().split()) + args[1:] 

775 elif b" " in args[0]: 

776 args = tuple(args[0].split()) + args[1:] 

777 

778 buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) 

779 

780 buffer_cutoff = self._buffer_cutoff 

781 for arg in map(self.encoder.encode, args): 

782 # to avoid large string mallocs, chunk the command into the 

783 # output list if we're sending large values or memoryviews 

784 arg_length = len(arg) 

785 if ( 

786 len(buff) > buffer_cutoff 

787 or arg_length > buffer_cutoff 

788 or isinstance(arg, memoryview) 

789 ): 

790 buff = SYM_EMPTY.join( 

791 (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) 

792 ) 

793 output.append(buff) 

794 output.append(arg) 

795 buff = SYM_CRLF 

796 else: 

797 buff = SYM_EMPTY.join( 

798 ( 

799 buff, 

800 SYM_DOLLAR, 

801 str(arg_length).encode(), 

802 SYM_CRLF, 

803 arg, 

804 SYM_CRLF, 

805 ) 

806 ) 

807 output.append(buff) 

808 return output 

809 

810 def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: 

811 """Pack multiple commands into the Redis protocol""" 

812 output: List[bytes] = [] 

813 pieces: List[bytes] = [] 

814 buffer_length = 0 

815 buffer_cutoff = self._buffer_cutoff 

816 

817 for cmd in commands: 

818 for chunk in self.pack_command(*cmd): 

819 chunklen = len(chunk) 

820 if ( 

821 buffer_length > buffer_cutoff 

822 or chunklen > buffer_cutoff 

823 or isinstance(chunk, memoryview) 

824 ): 

825 if pieces: 

826 output.append(SYM_EMPTY.join(pieces)) 

827 buffer_length = 0 

828 pieces = [] 

829 

830 if chunklen > buffer_cutoff or isinstance(chunk, memoryview): 

831 output.append(chunk) 

832 else: 

833 pieces.append(chunk) 

834 buffer_length += chunklen 

835 

836 if pieces: 

837 output.append(SYM_EMPTY.join(pieces)) 

838 return output 

839 

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): peeks at StreamReader's private ``_buffer`` -- relies
        # on asyncio internals; re-verify on Python upgrades.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        # Drain any push messages already buffered on the socket.
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        # Stash a token for the next re_auth() call.
        self._re_auth_token = token

850 

    async def re_auth(self):
        # Re-authenticate with the most recently stashed token, then clear it
        # so the same token is not replayed on a later call.
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

860 

861 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        # TCP endpoint and socket tuning; everything else is handled by
        # AbstractConnection.__init__.
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """(name, value) pairs shown in repr()."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to asyncio.open_connection()."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle's algorithm: Redis commands are small and
            # latency-sensitive.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # host:port string used in error messages.
        return f"{self.host}:{self.port}"

917 

918 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

924 

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a TLS-wrapped TCP connection.

        All ``ssl_*`` parameters are forwarded to ``RedisSSLContext``, which
        supplies the actual ``ssl.SSLContext`` via its ``get()`` method when
        connecting; remaining keyword arguments go to ``Connection``.

        :raises RedisError: if the interpreter was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

959 

960 def _connection_arguments(self) -> Mapping: 

961 kwargs = super()._connection_arguments() 

962 kwargs["ssl"] = self.ssl_context.get() 

963 return kwargs 

964 

965 @property 

966 def keyfile(self): 

967 return self.ssl_context.keyfile 

968 

969 @property 

970 def certfile(self): 

971 return self.ssl_context.certfile 

972 

973 @property 

974 def cert_reqs(self): 

975 return self.ssl_context.cert_reqs 

976 

977 @property 

978 def include_verify_flags(self): 

979 return self.ssl_context.include_verify_flags 

980 

981 @property 

982 def exclude_verify_flags(self): 

983 return self.ssl_context.exclude_verify_flags 

984 

985 @property 

986 def ca_certs(self): 

987 return self.ssl_context.ca_certs 

988 

989 @property 

990 def ca_data(self): 

991 return self.ssl_context.ca_data 

992 

993 @property 

994 def check_hostname(self): 

995 return self.ssl_context.check_hostname 

996 

997 @property 

998 def min_version(self): 

999 return self.ssl_context.min_version 

1000 

1001 

class RedisSSLContext:
    """Holds SSL configuration and lazily builds the :class:`ssl.SSLContext`
    used by :class:`SSLConnection`."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted annotation: ``ssl`` is None on Python builds without SSL
        # support, so an eagerly evaluated ``ssl.VerifyMode`` reference would
        # raise AttributeError at import time rather than the intended
        # RedisError at construction.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Store SSL configuration; the SSLContext itself is built on first
        :meth:`get`.

        Raises:
            RedisError: if Python was built without SSL support, or when
                ``cert_reqs`` is an unrecognized string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            # Map the human-friendly spellings onto ssl.VerifyMode values.
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is rejected by the ssl module when certificate
        # verification is disabled, so force it off under CERT_NONE.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        # Built lazily by get() and cached for the lifetime of this object.
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the configured SSLContext, creating and caching it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            # check_hostname is assigned before verify_mode: lowering
            # verify_mode to CERT_NONE is only legal once check_hostname
            # is False (guaranteed by __init__ in that case).
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

1093 

1094 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        # Filesystem path of the Redis server's unix domain socket.
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        # (label, value) pairs used to build this connection's repr().
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        # Open the UDS stream, bounded by the configured connect timeout.
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        # NOTE(review): on_connect() is invoked here, while the TCP
        # Connection._connect() visible above does not call it — confirm the
        # handshake is not run twice for UDS connections.
        await self.on_connect()

    def _host_error(self) -> str:
        # UDS connections are identified by their socket path in errors.
        return self.path

1117 

1118 

# Querystring tokens (compared upper-cased) that are interpreted as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a URL querystring value to a boolean.

    Returns ``None`` for absent/empty values, ``False`` for recognized
    false-y string spellings (case-insensitive), otherwise ``bool(value)``.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        # Non-empty string: truthy unless it spells one of the false tokens.
        return value.upper() not in FALSE_STRINGS
    return bool(value)

1128 

1129 

1130def parse_ssl_verify_flags(value): 

1131 # flags are passed in as a string representation of a list, 

1132 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1133 verify_flags_str = value.replace("[", "").replace("]", "") 

1134 

1135 verify_flags = [] 

1136 for flag in verify_flags_str.split(","): 

1137 flag = flag.strip() 

1138 if not hasattr(VerifyFlags, flag): 

1139 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1140 verify_flags.append(getattr(VerifyFlags, flag)) 

1141 return verify_flags 

1142 

1143 

# Casts applied to known querystring options when parsing a redis:// URL;
# options not listed here are passed through to the pool as plain strings.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)

1159 

1160 

class ConnectKwargs(TypedDict, total=False):
    """Connection keyword arguments produced by :func:`parse_url`.

    All keys are optional (``total=False``); which ones are present depends
    on the URL scheme and its query string.
    """

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str

1169 

1170 

def parse_url(url: str) -> ConnectKwargs:
    """Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into
    connection keyword arguments.

    Username, password, hostname, path and querystring values are
    percent-decoded; known querystring options are cast via
    ``URL_QUERY_ARGUMENT_PARSERS``.

    Raises:
        ValueError: for an unsupported scheme or an uncastable option value.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    # Apply querystring options, casting known keys to their proper types.
    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        value = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = value
        else:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    scheme = parsed.scheme
    if scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # We only support redis://, rediss:// and unix:// schemes.
    if scheme not in ("redis", "rediss"):
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # If there's a path argument, use it as the db argument if a
    # querystring value wasn't specified
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs

1222 

1223 

1224_CP = TypeVar("_CP", bound="ConnectionPool") 

1225 

1226 

class ConnectionPool:
    """
    Create a connection pool. ``If max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options take precedence over conflicting kwargs.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # No explicit limit means "effectively unbounded" (2**31 connections).
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # Idle connections ready for reuse.
        self._available_connections: List[AbstractConnection] = []
        # Connections currently checked out by callers.
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        # Guards the bookkeeping above during get_connection().
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        # Forget all tracked connections without disconnecting them.
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set here while reset() switches
        # to a WeakSet (abandoned in-use connections become collectable) —
        # confirm the asymmetry is intentional.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        return (
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # Track connection count before to detect if a new connection is created
        async with self._lock:
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            # Only record creation time for connections actually created here.
            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            # Hand the connection back so it is not leaked, then re-raise.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            # Pool is empty: create a new connection unless at the cap.
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            # Drop the transport now so the next checkout reconnects fresh.
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        # Surface the first failure only after every disconnect was attempted.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        # Apply the retry policy to every connection, idle and in use.
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        # Re-authenticate idle connections immediately; in-use connections
        # store the token and re-authenticate when appropriate.
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        # Build one attribute set per connection state for observability.
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]

1500 

1501 

class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Waiters block on this condition until a connection becomes free.
        self._condition = asyncio.Condition()
        # Seconds to wait for a free connection; None means block forever.
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            # Waiting for a free slot timed out: report it as a
            # ConnectionError, chaining the original timeout.
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            # Only record creation time for connections actually created here.
            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            await record_connection_wait_time(
                pool_name=get_pool_name(self),
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            # Hand the connection back so waiters are notified, then re-raise.
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one waiter now that a connection slot is free.
            self._condition.notify()