Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

707 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import warnings 

8import weakref 

9from abc import abstractmethod 

10from itertools import chain 

11from types import MappingProxyType 

12from typing import ( 

13 Any, 

14 Callable, 

15 Iterable, 

16 List, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26) 

27from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

28 

29from ..utils import SSL_AVAILABLE 

30 

31if SSL_AVAILABLE: 

32 import ssl 

33 from ssl import SSLContext, TLSVersion, VerifyFlags 

34else: 

35 ssl = None 

36 TLSVersion = None 

37 SSLContext = None 

38 VerifyFlags = None 

39 

40from ..auth.token import TokenInterface 

41from ..driver_info import DriverInfo, resolve_driver_info 

42from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

43from ..utils import deprecated_args, format_error_message 

44 

45# the functionality is available in 3.11.x but has a major issue before 

46# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

47if sys.version_info >= (3, 11, 3): 

48 from asyncio import timeout as async_timeout 

49else: 

50 from async_timeout import timeout as async_timeout 

51 

52from redis.asyncio.retry import Retry 

53from redis.backoff import NoBackoff 

54from redis.connection import DEFAULT_RESP_VERSION 

55from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

56from redis.exceptions import ( 

57 AuthenticationError, 

58 AuthenticationWrongNumberOfArgsError, 

59 ConnectionError, 

60 DataError, 

61 MaxConnectionsError, 

62 RedisError, 

63 ResponseError, 

64 TimeoutError, 

65) 

66from redis.typing import EncodableT 

67from redis.utils import HIREDIS_AVAILABLE, str_if_bytes 

68 

69from .._parsers import ( 

70 BaseParser, 

71 Encoder, 

72 _AsyncHiredisParser, 

73 _AsyncRESP2Parser, 

74 _AsyncRESP3Parser, 

75) 

76 

# Byte tokens used by pack_command()/pack_commands() when serializing
# commands into the RESP wire format.
(SYM_STAR, SYM_DOLLAR, SYM_CRLF, SYM_LF, SYM_EMPTY) = (
    b"*",
    b"$",
    b"\r\n",
    b"\n",
    b"",
)

82 

83 

84class _Sentinel(enum.Enum): 

85 sentinel = object() 

86 

87 

88SENTINEL = _Sentinel.sentinel 

89 

90 

# Use the hiredis-backed parser when hiredis is available; otherwise fall back
# to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
)

96 

97 

class ConnectCallbackProtocol(Protocol):
    """Plain callable that receives the connection (see ``redis_connect_func``)."""

    def __call__(self, connection: "AbstractConnection"):
        """Invoked with the connection being set up."""


class AsyncConnectCallbackProtocol(Protocol):
    """Coroutine-function variant of :class:`ConnectCallbackProtocol`."""

    async def __call__(self, connection: "AbstractConnection"):
        """Awaited with the connection being set up."""


# Either flavour is accepted wherever a connect callback is expected.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

107 

108 

class AbstractConnection:
    """Manages communication to and from a Redis server.

    Subclasses provide the transport by implementing ``_connect`` (open the
    reader/writer streams), ``_host_error`` (host description used in error
    messages) and ``repr_pieces`` (key/value pairs rendered by ``__repr__``).
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        db : str or int
            Database index; when truthy it is selected with SELECT on connect.
        password, username : str, optional
            Static credentials. Mutually exclusive with ``credential_provider``.
        socket_timeout : float, optional
            Timeout for writes and (by default) reads.
        socket_connect_timeout : float, optional
            Timeout for opening and closing the connection; defaults to
            ``socket_timeout`` when None.
        retry_on_timeout : bool
            When True, TimeoutError, socket.timeout and asyncio.TimeoutError are
            added to the errors that trigger a retry.
        retry_on_error : list, optional
            Exception types that trigger a retry. The list is copied, so the
            caller's object is never modified.
        retry : Retry, optional
            Retry policy; deep-copied because Retry objects are mutable.
        protocol : int or str, optional
            RESP version, 2 or 3. A value ``int()`` rejects with TypeError
            (e.g. None) falls back to ``DEFAULT_RESP_VERSION``.
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If username/password are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            # Copy: the caller's list may be shared (e.g. reused kwargs for
            # every connection in a pool) and must not grow on each __init__.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        # Validate the RESP version. The range check runs after the try
        # statement (not in a ``finally``) so that the ValueError branch
        # reports a ConnectionError instead of an UnboundLocalError on ``p``.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
            # Store the resolved default, not the unusable input, so that
            # on_connect never sends e.g. "HELLO None".
            protocol = p
        except ValueError:
            raise ConnectionError("protocol must be an integer") from None
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        # Keep the caller's spelling (2/3 or "2"/"3"); the rest of the class
        # compares against both forms.
        self.protocol = protocol

    def __del__(self, _warnings: Any = warnings):
        """Warn about (and best-effort close) a connection left open."""
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                # Only close when an event loop is running in this thread.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs rendered by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the reader and writer streams are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            # Not registered (or already removed): nothing to do.
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error: self.disconnect(),
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (if needed), run the handshake and connect callbacks.

        :param check_health: forwarded to the default handshake
            (``on_connect_check_health``).
        :param retry_socket_connect: when True the socket open is wrapped in
            ``self.retry``; otherwise it is attempted once.
        :raises TimeoutError: if opening the socket times out.
        :raises ConnectionError: for any other failure while opening the socket.
        """
        if self.is_connected:
            return
        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect()
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func. Call it and await
                # whatever awaitable it returns, so sync callables, coroutine
                # functions and callables returning awaitables all work.
                result = self.redis_connect_func(self)
                if inspect.isawaitable(result):
                    await result
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection so ``should_reconnect()`` returns True."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return the reconnect flag set by ``mark_for_reconnect()``."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done by ``disconnect()``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader`` / ``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a host description used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol value."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    def _switch_to_resp3_parser(self, previous_parser: BaseParser) -> None:
        """Replace a RESP2 parser with a RESP3 parser before sending HELLO.

        The ``EXCEPTION_CLASSES`` of ``previous_parser`` are carried over and the
        new parser is attached to this connection. No-op for other parsers.
        """
        if isinstance(self._parser, _AsyncRESP2Parser):
            self.set_parser(_AsyncRESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = previous_parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the connection handshake.

        Authenticates (via HELLO ... AUTH for non-RESP2 protocols, else AUTH),
        negotiates the protocol, sets the client name and library info, and
        selects ``self.db``. CLIENT SETINFO errors are ignored.

        :raises ConnectionError: on an unexpected HELLO/SETNAME/SELECT reply.
        :raises AuthenticationError: if AUTH does not reply OK.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser(parser)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser(parser)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # NOTE: unlike the authenticated HELLO above, the "proto" field of
            # this reply is intentionally not validated.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline; SETINFO errors are deliberately ignored
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server

        :param nowait: when True, do not wait for the writer to finish closing.
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if necessary.

        The connection is disconnected on any write failure.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on OSError while writing.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param timeout: per-call read timeout; when it expires, None is returned
            instead of raising. Defaults to ``socket_timeout`` (which raises).
        :param disconnect_on_error: disconnect on read errors/timeouts.
        :param push_request: forwarded to the parser for RESP3 only.
        :raises TimeoutError: when ``socket_timeout`` expires.
        :raises ConnectionError: on OSError while reading.
        :raises ResponseError: when the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}") from None
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol

        Returns a list of byte chunks. Arguments longer than ``_buffer_cutoff``
        and memoryviews are emitted as separate chunks to avoid large copies.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol

        Small chunks are coalesced up to ``_buffer_cutoff``; large chunks and
        memoryviews are passed through individually.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): reads the StreamReader's private ``_buffer``; this
        # raises AttributeError when not connected.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read (as push messages) everything already buffered on the reader."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth()`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

777 

778 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Create a TCP connection description (no socket is opened here).

        :param host: server host name or address.
        :param port: server port; converted with ``int()``.
        :param socket_keepalive: enable SO_KEEPALIVE on connect.
        :param socket_keepalive_options: extra ``SOL_TCP`` options applied when
            ``socket_keepalive`` is enabled, as ``{option: value}``.
        :param socket_type: stored on the instance; not used by this class.
        :param kwargs: forwarded to :class:`AbstractConnection`.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection

        The streams are only published to ``self._reader`` / ``self._writer``
        after all socket options were applied successfully; on failure the
        writer is closed and the connection is left in the "not connected"
        state instead of holding references to closed streams.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)
            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

834 

835 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted: parameter annotations are evaluated when this ``def`` runs,
        # and ``ssl`` is None when Python lacks SSL support, so an unquoted
        # ``ssl.VerifyMode`` would make the module fail to import.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Build the SSL settings holder and initialize the TCP connection.

        All ``ssl_*`` arguments are forwarded (without the prefix) to
        :class:`RedisSSLContext`; remaining ``kwargs`` go to :class:`Connection`.

        :raises RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP connection arguments plus the ``ssl`` context to use."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the settings stored on ``self.ssl_context``.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

917 

918 

class RedisSSLContext:
    """
    Hold TLS settings for a Redis connection and build the ``SSLContext``.

    The constructor only validates and stores the settings. The
    :class:`ssl.SSLContext` is created on the first call to :meth:`get` and
    cached on ``self.context``, so later calls return the same object.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted (like the VerifyFlags annotations below) so the signature does
        # not dereference ``ssl`` at definition time: ``ssl`` is ``None`` when
        # Python was built without SSL support.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """
        :param keyfile: client private key passed to ``load_cert_chain``.
        :param certfile: client certificate passed to ``load_cert_chain``.
        :param cert_reqs: an ``ssl.VerifyMode`` or one of the strings
            ``"none"``, ``"optional"``, ``"required"``; ``None`` means
            ``ssl.CERT_NONE``.
        :param include_verify_flags: flags OR-ed into ``verify_flags``.
        :param exclude_verify_flags: flags cleared from ``verify_flags``.
        :param ca_certs: ``cafile`` for ``load_verify_locations``.
        :param ca_data: ``cadata`` for ``load_verify_locations``.
        :param ca_path: ``capath`` for ``load_verify_locations``.
        :param check_hostname: forced to ``False`` when ``cert_reqs`` resolves
            to ``ssl.CERT_NONE``.
        :param min_version: value for ``SSLContext.minimum_version``.
        :param ciphers: cipher string for ``SSLContext.set_ciphers``.
        :param password: password for ``keyfile``.
        :raises RedisError: if SSL is unavailable or ``cert_reqs`` is an
            unknown string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is meaningless without certificate verification.
        # ssl also refuses verify_mode=CERT_NONE while check_hostname is on.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """
        Return the cached ``SSLContext``, building it on first use.

        :return: an ``ssl.SSLContext`` configured from this object's settings.
        """
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname must be set before verify_mode: the default
            # context has check_hostname=True, which forbids CERT_NONE.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

1010 

1011 

class UnixDomainSocketConnection(AbstractConnection):
    """Connection to a Redis server over a Unix domain socket (UDS)."""

    def __init__(self, *, path: str = "", **kwargs):
        # ``path`` is stored before the base initializer runs.
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Return (name, value) pairs describing this connection for repr()."""
        base = [("path", self.path), ("db", self.db)]
        if not self.client_name:
            return base
        return base + [("client_name", self.client_name)]

    async def _connect(self):
        """Open the socket streams within the connect timeout, then handshake."""
        async with async_timeout(self.socket_connect_timeout):
            streams = await asyncio.open_unix_connection(path=self.path)
        self._reader, self._writer = streams
        # NOTE(review): on_connect() runs here, inside _connect. Confirm that
        # the caller (connect()) does not also run it, or the handshake would
        # be repeated.
        await self.on_connect()

    def _host_error(self) -> str:
        """Describe the endpoint in error messages: the socket path."""
        return self.path

1034 

1035 

# Upper-cased string spellings that to_bool() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """
    Convert a URL/config value to a boolean.

    ``None`` and ``""`` mean "not set" and give ``None``. Strings matching
    ``FALSE_STRINGS`` case-insensitively give ``False``. Anything else goes
    through ``bool()``.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)

1045 

1046 

def parse_ssl_verify_flags(value):
    """
    Parse ``ssl.VerifyFlags`` member names from a string.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"`` or
    ``"VERIFY_X509_STRICT,VERIFY_X509_PARTIAL_CHAIN"``. Brackets and
    surrounding whitespace are ignored, and so are empty entries: ``"[]"``
    yields an empty list.

    :param value: the raw string, typically a URL query-string value.
    :return: a list of ``ssl.VerifyFlags`` members, in input order.
    :raises ValueError: if an entry is not the name of a ``VerifyFlags``
        member, or SSL support is unavailable.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # Look names up in the enum's member table rather than with hasattr():
    # hasattr/getattr also accept arbitrary class attributes such as
    # "__doc__" or "mro", which are not verify flags.
    members = VerifyFlags.__members__ if VerifyFlags is not None else {}

    verify_flags = []
    for raw_flag in verify_flags_str.split(","):
        flag = raw_flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags

1059 

1060 

# Read-only table mapping URL query-string option names to the callable that
# converts the raw string value. parse_url passes options that are not listed
# here through as plain strings.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        ssl_include_verify_flags=parse_ssl_verify_flags,
        ssl_exclude_verify_flags=parse_ssl_verify_flags,
        timeout=float,
    )
)

1076 

1077 

# Keyword arguments produced by parse_url(). Every key is optional
# (total=False). parse_url may also add extra query-string options that are
# not declared here.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)

1087 

def parse_url(url: str) -> ConnectKwargs:
    """
    Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into connection kwargs.

    :param url: the connection URL.
    :return: a ``ConnectKwargs`` dict. Query-string options are included:
        converted through ``URL_QUERY_ARGUMENT_PARSERS`` when a parser is
        registered for the name, otherwise kept as strings.
    :raises ValueError: if a query-string value cannot be converted, or the
        URL scheme is not one of the supported schemes.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        if not value_list:
            continue
        # Only the first occurrence of a repeated option is used.
        # NOTE(review): parse_qs has already percent-decoded the value, so this
        # unquote decodes it a second time (e.g. "%2525" -> "%"). This is kept
        # for backward compatibility with from_url's documented behaviour;
        # confirm before changing.
        value = unquote(value_list[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = value
            continue
        try:
            kwargs[name] = parser(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid value for '{name}' in connection URL."
            ) from err

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified. A non-numeric path is ignored.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs

1139 

1140 

# Type variable bound to ConnectionPool. ``ConnectionPool.from_url`` takes
# ``cls: Type[_CP]`` and returns ``_CP``, so calling it on a subclass is typed
# as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")

1142 

1143 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override explicit keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: class instantiated by :meth:`make_connection`.
        :param max_connections: maximum number of connections checked out at
            once. ``None`` or ``0`` means effectively unbounded (``2**31``).
        :param connection_kwargs: forwarded verbatim to ``connection_class``.
            ``encoder_class`` and ``event_dispatcher`` are also read here.
        :raises ValueError: if ``max_connections`` is not a non-negative int.
        """
        # ``or`` maps both None and 0 to the effectively-unlimited default.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all tracked connections without disconnecting them."""
        self._available_connections = []
        # NOTE(review): __init__ uses a plain set here but reset() uses a
        # WeakSet, so in-use connections are then only weakly referenced.
        # Presumably deliberate; confirm.
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        return (
            bool(self._available_connections)
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connected connection from the pool.

        :raises MaxConnectionsError: if the pool is exhausted.
        """
        async with self._lock:
            connection = self.get_available_connection()

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            # Hand the connection back so the in-use slot is not leaked.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """
        Get a connection from the pool, without making sure it is connected.

        An idle connection is reused if one exists; otherwise a new one is
        created while under ``max_connections``.

        :raises MaxConnectionsError: if no idle connection exists and the limit
            has been reached.
        """
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """
        Ensure that the connection object is connected and valid.

        :raises ConnectionError: if the connection still has unread data after
            one reconnect attempt.
        """
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """
        Release the connection back to the pool.

        Connections flagged by ``should_reconnect()`` are disconnected first.
        An ``AsyncAfterConnectionReleasedEvent`` is dispatched afterwards.
        """
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnect connections in the pool.

        If ``inuse_connections`` is True, also disconnect connections that are
        currently in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects are attempted concurrently. If any failed, the first
        exception found in the results is re-raised.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign ``retry`` to every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate the pool's connections with ``token``.

        Idle connections send ``AUTH`` and read the reply right away (with
        retries). In-use connections are only handed the token through
        ``set_re_auth_token``.
        """
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        No-op function, passed as the error callback to the retry object.

        :param error: the error reported by the retry object (ignored).
        :return: None
        """
        pass

1380 

1381 

class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        """
        :param max_connections: pool size limit.
        :param timeout: seconds to wait for a free connection; ``None`` waits
            forever.
        :param connection_class: class used to create new connections.
        :param queue_class: deprecated; accepted for backward compatibility
            but not used.
        :param connection_kwargs: forwarded to ``connection_class``.
        """
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Waiters block on this condition until release() notifies them.
        self._condition = asyncio.Condition()
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection from the pool, blocking until one is available.

        :raises ConnectionError: if none becomes available within
            ``self.timeout`` seconds.
        """
        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            try:
                await super().release(connection)
            finally:
                # Wake one waiter even if the base release raised part-way
                # (e.g. while disconnecting or dispatching the release event).
                # The connection may already have left the in-use set, freeing
                # capacity. A spurious wake-up is harmless because wait_for()
                # re-checks can_get_connection() before returning.
                self._condition.notify()