Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

704 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import warnings 

8import weakref 

9from abc import abstractmethod 

10from itertools import chain 

11from types import MappingProxyType 

12from typing import ( 

13 Any, 

14 Callable, 

15 Iterable, 

16 List, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26) 

27from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

28 

29from ..utils import SSL_AVAILABLE 

30 

31if SSL_AVAILABLE: 

32 import ssl 

33 from ssl import SSLContext, TLSVersion, VerifyFlags 

34else: 

35 ssl = None 

36 TLSVersion = None 

37 SSLContext = None 

38 VerifyFlags = None 

39 

40from ..auth.token import TokenInterface 

41from ..driver_info import DriverInfo, resolve_driver_info 

42from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

43from ..utils import deprecated_args, format_error_message 

44 

45# the functionality is available in 3.11.x but has a major issue before 

46# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

47if sys.version_info >= (3, 11, 3): 

48 from asyncio import timeout as async_timeout 

49else: 

50 from async_timeout import timeout as async_timeout 

51 

52from redis.asyncio.retry import Retry 

53from redis.backoff import NoBackoff 

54from redis.connection import DEFAULT_RESP_VERSION 

55from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

56from redis.exceptions import ( 

57 AuthenticationError, 

58 AuthenticationWrongNumberOfArgsError, 

59 ConnectionError, 

60 DataError, 

61 MaxConnectionsError, 

62 RedisError, 

63 ResponseError, 

64 TimeoutError, 

65) 

66from redis.typing import EncodableT 

67from redis.utils import HIREDIS_AVAILABLE, str_if_bytes 

68 

69from .._parsers import ( 

70 BaseParser, 

71 Encoder, 

72 _AsyncHiredisParser, 

73 _AsyncRESP2Parser, 

74 _AsyncRESP3Parser, 

75) 

76 

# Byte tokens used to frame commands in the Redis wire protocol
# (see pack_command / pack_commands below).
SYM_STAR = b"*"  # precedes the number of arguments in a packed command
SYM_DOLLAR = b"$"  # precedes the byte length of each packed argument
SYM_CRLF = b"\r\n"  # terminator written after each header and argument
SYM_LF = b"\n"  # bare line feed
SYM_EMPTY = b""  # empty separator used for bytes joins

82 

83 

class _Sentinel(enum.Enum):
    """Single-member enum whose member marks an argument as "not supplied".

    Using an enum member (instead of a bare ``object()``) gives the marker a
    readable repr and a type usable in annotations such as
    ``Union[list, _Sentinel]``.
    """

    sentinel = object()


# The one marker value; compare against it with ``is``.
SENTINEL = _Sentinel.sentinel

89 

90 

# Parser class used when a connection is not given an explicit
# ``parser_class``: the hiredis-backed parser when hiredis is available,
# otherwise the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
)

96 

97 

class ConnectCallbackProtocol(Protocol):
    """Structural type for a synchronous callable that receives a connection."""

    def __call__(self, connection: "AbstractConnection"):
        ...


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for an ``async`` callable that receives a connection."""

    async def __call__(self, connection: "AbstractConnection"):
        ...


# Either flavour is accepted wherever a connect callback is expected
# (e.g. ``redis_connect_func``).
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

107 

108 

class AbstractConnection:
    """Manages communication to and from a Redis server.

    Transport-specific subclasses implement ``_connect``, ``_host_error`` and
    ``repr_pieces``. This base class owns the connection handshake, command
    packing, response reading, health checks and retry configuration.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        socket_connect_timeout : float, optional
            Timeout for establishing (and closing) the connection. Defaults to
            ``socket_timeout`` when not given.
        retry_on_error : list, optional
            Exception classes that should trigger a retry. The caller's list
            is never modified: when ``retry_on_timeout`` is true, a new list
            with the timeout exceptions appended is stored instead.
        protocol : int or str, optional
            RESP version, 2 or 3. ``None`` selects ``DEFAULT_RESP_VERSION``.
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Build a new list instead of appending in place: the caller's
            # list may be reused for many connections (e.g. shared connection
            # kwargs), and in-place appends would make it grow with duplicate
            # entries on every construction.
            retry_on_error = [
                *retry_on_error,
                TimeoutError,
                socket.timeout,
                asyncio.TimeoutError,
            ]
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        # -1 is below any event-loop time, so the first check_health() pings
        # (when a health_check_interval is configured).
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        # Arguments/buffers longer than this are sent as separate chunks by
        # pack_command / pack_commands instead of being concatenated.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        # Validate the RESP version. ``None`` means "use the default"; the
        # default is stored so later ``self.protocol`` comparisons (e.g. in
        # the handshake) never see ``None``.
        if protocol is None:
            protocol = DEFAULT_RESP_VERSION
        try:
            p = int(protocol)
        except (TypeError, ValueError):
            raise ConnectionError("protocol must be an integer") from None
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = protocol

    def __del__(self, _warnings: Any = warnings):
        """Warn about and best-effort close a connection left open at GC time."""
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No running event loop (or the transport refused to close):
                # nothing more can be done from a finalizer.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs shown by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            # Not registered: deregistration is idempotent.
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error: self.disconnect(),
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the transport, run the handshake and fire connect callbacks.

        Does nothing if already connected.

        :param check_health: forwarded to the default handshake
            (``on_connect_check_health``).
        :param retry_socket_connect: when true, the raw socket connect alone is
            wrapped in ``self.retry``.
        :raises TimeoutError: if connecting times out.
        :raises ConnectionError: for any other failure while connecting.
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback is None:
                # The owner was collected while an earlier callback was
                # awaited; skip it rather than calling None.
                continue
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection as needing a reconnect."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect`` has been called."""
        return self._should_reconnect

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``self._reader`` / ``self._writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a host description used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake on a freshly opened transport.

        Depending on configuration this authenticates (``HELLO ... AUTH`` when
        a non-2 protocol is set, otherwise ``AUTH``), switches protocol with
        ``HELLO``, sets the client name, sends ``CLIENT SETINFO`` and selects
        ``self.db``.

        :param check_health: forwarded to the non-authentication commands.
        :raises AuthenticationError: if ``AUTH`` does not reply ``OK``.
        :raises ConnectionError: on a mismatched ``HELLO`` protocol reply
            (auth path), a failed ``CLIENT SETNAME`` or a failed ``SELECT``.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The reply map may use bytes or str keys depending on decoding.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # The HELLO reply is consumed but its "proto" field is not
            # validated on this path.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # SETINFO errors are deliberately ignored (presumably for
                # servers that do not support the command).
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server

        :param nowait: when true, do not await ``wait_closed()`` (used when
            handling errors and forcefully disconnecting).
        :raises TimeoutError: if closing takes longer than
            ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG

        Only pings when ``health_check_interval`` is set and the event-loop
        time has passed ``next_health_check`` (refreshed by read_response).
        """
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on an OSError while writing.
        On any failure the connection is disconnected without waiting.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param timeout: per-call timeout; when it expires ``None`` is returned
            instead of raising. Falls back to ``socket_timeout``.
        :param disconnect_on_error: disconnect (without waiting) on read errors.
        :param push_request: forwarded to the parser only for RESP3.
        :raises TimeoutError: if ``socket_timeout`` (not ``timeout``) expires.
        :raises ConnectionError: on an OSError while reading.
        :raises ResponseError: if the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol

        Produces ``*<argc>\\r\\n`` followed by ``$<len>\\r\\n<arg>\\r\\n`` per
        argument. Large arguments and memoryviews are emitted as separate list
        items instead of being concatenated.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol

        Small chunks from consecutive commands are coalesced into one buffer;
        chunks above ``_buffer_cutoff`` and memoryviews are kept separate.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty

        NOTE: inspects the private ``StreamReader._buffer`` attribute.
        """
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read push messages until the reader's buffer is drained."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send ``AUTH <oid> <token value>`` for a pending token, then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

772 

773 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Create a TCP connection description (no I/O happens here).

        :param port: converted with ``int()``.
        :param socket_keepalive: enable ``SO_KEEPALIVE`` on connect.
        :param socket_keepalive_options: ``{option: value}`` pairs applied at
            the ``SOL_TCP`` level when keepalive is enabled.
        Remaining keyword arguments are passed to ``AbstractConnection``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection

        The streams are stored on ``self`` only after all socket options were
        applied, so a failure leaves the connection reported as disconnected
        rather than holding a closed writer.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)

            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

829 

830 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted so the annotation is not evaluated at class-definition time;
        # `ssl` is None when Python lacks SSL support.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Build the SSL settings, then initialize the TCP connection.

        The ``ssl_*`` arguments are forwarded (without the prefix) to
        ``RedisSSLContext``; remaining keyword arguments go to ``Connection``.

        :raises RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP connection arguments plus the ``ssl`` context to use."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the settings held by ``self.ssl_context``.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

912 

913 

class RedisSSLContext:
    """Hold TLS settings and lazily build an :class:`ssl.SSLContext` from them.

    Construction only validates and stores the settings. The real context
    is created on the first call to :meth:`get` and cached in
    ``self.context``, so later calls return the same object.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted so the annotation is not evaluated at definition time: when
        # Python lacks SSL support the module sets ``ssl = None`` and an
        # unquoted ``ssl.VerifyMode`` raised AttributeError on import.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Validate and store TLS settings; no SSL context is built yet.

        :param keyfile: passed to ``load_cert_chain(keyfile=...)``.
        :param certfile: passed to ``load_cert_chain(certfile=...)``.
        :param cert_reqs: ``None`` (meaning ``ssl.CERT_NONE``), an
            ``ssl.VerifyMode`` value, or one of the strings ``"none"``,
            ``"optional"``, ``"required"``.
        :param include_verify_flags: flags OR-ed into ``verify_flags``.
        :param exclude_verify_flags: flags cleared from ``verify_flags``.
        :param ca_certs: ``cafile`` for ``load_verify_locations``.
        :param ca_data: ``cadata`` for ``load_verify_locations``.
        :param ca_path: ``capath`` for ``load_verify_locations``.
        :param check_hostname: forced to ``False`` when ``cert_reqs`` resolves
            to ``ssl.CERT_NONE``.
        :param min_version: assigned to ``minimum_version`` when not ``None``.
        :param ciphers: passed to ``set_ciphers`` when not ``None``.
        :param password: password for the private key in ``load_cert_chain``.
        :raises RedisError: if Python was built without SSL support, or if
            ``cert_reqs`` is a string other than the three listed above.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # ``SSLContext`` does not allow hostname checking together with
        # CERT_NONE, so hostname checks only apply when certs are verified.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the SSL context, building and caching it on the first call.

        Errors from loading certificates, keys or CA data, and from
        ``set_ciphers``, propagate to the caller unchanged.
        """
        if self.context is None:
            context = ssl.create_default_context()
            # Order matters: the default context starts with hostname checking
            # enabled, and ``verify_mode = CERT_NONE`` is only accepted after
            # ``check_hostname`` has been switched off.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            for flag in self.include_verify_flags or ():
                context.verify_flags |= flag
            for flag in self.exclude_verify_flags or ():
                context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

1005 

1006 

class UnixDomainSocketConnection(AbstractConnection):
    """Connection that talks to a Redis server over a Unix domain socket."""

    def __init__(self, *, path: str = "", **kwargs):
        """Store the socket ``path``; remaining kwargs go to the base class."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return ``(name, value)`` pairs used to build the connection repr."""
        described = [("path", self.path), ("db", self.db)]
        if self.client_name:
            described += [("client_name", self.client_name)]
        return described

    async def _connect(self):
        """Open the socket within ``socket_connect_timeout``, then handshake."""
        async with async_timeout(self.socket_connect_timeout):
            stream_reader, stream_writer = await asyncio.open_unix_connection(
                path=self.path
            )
        # Streams are attached only after the timeout scope exited cleanly.
        self._reader = stream_reader
        self._writer = stream_writer
        await self.on_connect()

    def _host_error(self) -> str:
        """Identify the endpoint in error messages by its socket path."""
        return self.path

1029 

1030 

# Case-insensitive spellings that ``to_bool`` maps to ``False``.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a URL query value as a boolean.

    ``None`` and the empty string yield ``None`` (no value given). Strings
    whose upper-cased form is in ``FALSE_STRINGS`` yield ``False``; anything
    else falls back to Python truthiness via ``bool(value)``.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)

1040 

1041 

1042def parse_ssl_verify_flags(value): 

1043 # flags are passed in as a string representation of a list, 

1044 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1045 verify_flags_str = value.replace("[", "").replace("]", "") 

1046 

1047 verify_flags = [] 

1048 for flag in verify_flags_str.split(","): 

1049 flag = flag.strip() 

1050 if not hasattr(VerifyFlags, flag): 

1051 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1052 verify_flags.append(getattr(VerifyFlags, flag)) 

1053 return verify_flags 

1054 

1055 

# Converters that ``parse_url`` applies to known query-string options.
# Options missing from this table are passed through as plain strings.
# Wrapped in MappingProxyType so the table is read-only at runtime.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        ssl_include_verify_flags=parse_ssl_verify_flags,
        ssl_exclude_verify_flags=parse_ssl_verify_flags,
        timeout=float,
    )
)

1071 

1072 

# Typed view of the keys ``parse_url`` derives from the URL's scheme,
# authority and path (query-string options are added on top of these).
# ``total=False``: every key is optional.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)

1081 

1082 

def parse_url(url: str) -> ConnectKwargs:
    """Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into kwargs.

    For each query-string key, the first value is unquoted. It is then
    converted with the matching entry in ``URL_QUERY_ARGUMENT_PARSERS``, or
    kept as a string when no converter is registered. Credentials, host,
    port, socket path and the path-encoded database number are added
    afterwards.

    :raises ValueError: if a converter rejects a query value, or if the
        scheme is not one of the three supported ones.
    """
    result: ParseResult = urlparse(url)
    options: ConnectKwargs = {}

    for key, values in parse_qs(result.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(key)
        if not convert:
            options[key] = raw
            continue
        try:
            options[key] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{key}' in connection URL.")

    if result.username:
        options["username"] = unquote(result.username)
    if result.password:
        options["password"] = unquote(result.password)

    scheme = result.scheme
    if scheme == "unix":
        if result.path:
            options["path"] = unquote(result.path)
        options["connection_class"] = UnixDomainSocketConnection
        return options

    if scheme not in ("redis", "rediss"):
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    if result.hostname:
        options["host"] = unquote(result.hostname)
    if result.port:
        options["port"] = int(result.port)

    # A path such as "/3" selects the database, unless the query string
    # already supplied ``db``; non-numeric paths are silently ignored.
    if result.path and "db" not in options:
        try:
            options["db"] = int(unquote(result.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if scheme == "rediss":
        options["connection_class"] = SSLConnection

    return options

1134 

1135 

# Type variable bound to ConnectionPool, used as ``cls: Type[_CP]`` in
# ``ConnectionPool.from_url`` so that calling it on a subclass is typed as
# returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")

1137 

1138 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override explicitly passed keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """Configure the pool; no connections are created here.

        :param connection_class: class instantiated by :meth:`make_connection`.
        :param max_connections: upper bound on in-use connections. ``None``
            or ``0`` means "effectively unlimited" (``2**31``).
        :param connection_kwargs: passed verbatim to ``connection_class``.
            ``encoder_class`` and ``event_dispatcher`` are also read here.
        :raises ValueError: if ``max_connections`` is negative or not an int.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all tracked connections without disconnecting them."""
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set here; reset switches to a
        # WeakSet, so connections still held by callers are no longer kept
        # alive by the pool. Confirm this asymmetry is intended.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() so callers get a real bool rather than the list object.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool.

        The positional/keyword arguments are deprecated and ignored.
        If the connection cannot be made ready, it is released back to the
        pool and the error is re-raised.
        """
        async with self._lock:
            connection = self.get_available_connection()

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected.

        Reuses the most recently released idle connection, otherwise creates
        a new one. The connection is recorded as in use.

        :raises MaxConnectionsError: if no idle connection exists and
            ``max_connections`` connections are already in use.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid.

        :raises ConnectionError: if, after one reconnect attempt, the
            connection still has unread data pending.
        """
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        Connections flagged by ``should_reconnect()`` are disconnected first;
        either way the connection becomes idle again and an
        ``AsyncAfterConnectionReleasedEvent`` is dispatched.
        """
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects run concurrently; if any of them fails, the first
        exception (in connection order) is re-raised after all have finished.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign ``retry`` to every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a new token.

        Idle connections send ``AUTH`` immediately (with retries). In-use
        connections only get the token stored via ``set_re_auth_token``.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas are awaited within the same iteration, so the
                # late-bound ``conn`` always refers to the current connection.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        No-op failure callback passed to ``Retry.call_with_retry``.

        :param error: the error reported by the retry object (ignored).
        :return: None
        """
        pass

1375 

1376 

class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    Like :py:class:`~redis.asyncio.ConnectionPool`, it keeps a set of
    reusable connections shared by async redis clients. The difference is
    what happens when every connection is in use. The default pool raises
    an error immediately, whereas this pool suspends the current `Task`
    until a connection is released or ``timeout`` seconds have elapsed.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to set how many seconds to wait for a connection, or
    ``None`` to wait forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        """Create the pool; ``queue_class`` is accepted but not used."""
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Waiters in get_connection sleep on this; release() wakes one up.
        self._condition = asyncio.Condition()
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Wait (up to ``self.timeout`` seconds) for a free connection.

        :raises ConnectionError: if no connection became available in time.
        """
        try:
            # The timeout scope is entered only after the condition's lock is
            # held, so it covers waiting for capacity, not for the lock.
            async with self._condition, async_timeout(self.timeout):
                await self._condition.wait_for(self.can_get_connection)
                conn = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # Readiness checks run without holding the condition's lock.
        try:
            await self.ensure_connection(conn)
        except BaseException:
            await self.release(conn)
            raise
        else:
            return conn

    async def release(self, connection: AbstractConnection):
        """Return ``connection`` to the pool and wake a single waiter."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify()