Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

825 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import time 

8import warnings 

9import weakref 

10from abc import ABC, abstractmethod 

11from itertools import chain 

12from types import MappingProxyType 

13from typing import ( 

14 Any, 

15 Callable, 

16 Iterable, 

17 List, 

18 Mapping, 

19 Optional, 

20 Protocol, 

21 Set, 

22 Tuple, 

23 Type, 

24 TypedDict, 

25 TypeVar, 

26 Union, 

27) 

28from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

29 

30from ..observability.attributes import ( 

31 DB_CLIENT_CONNECTION_POOL_NAME, 

32 DB_CLIENT_CONNECTION_STATE, 

33 AttributeBuilder, 

34 ConnectionState, 

35 get_pool_name, 

36) 

37from ..utils import SSL_AVAILABLE 

38 

39if SSL_AVAILABLE: 

40 import ssl 

41 from ssl import SSLContext, TLSVersion, VerifyFlags 

42else: 

43 ssl = None 

44 TLSVersion = None 

45 SSLContext = None 

46 VerifyFlags = None 

47 

48from ..auth.token import TokenInterface 

49from ..driver_info import DriverInfo, resolve_driver_info 

50from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

51from ..utils import deprecated_args, format_error_message 

52 

53# the functionality is available in 3.11.x but has a major issue before 

54# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

55if sys.version_info >= (3, 11, 3): 

56 from asyncio import timeout as async_timeout 

57else: 

58 from async_timeout import timeout as async_timeout 

59 

60from redis.asyncio.observability.recorder import ( 

61 record_connection_closed, 

62 record_connection_count, 

63 record_connection_create_time, 

64 record_connection_wait_time, 

65 record_error_count, 

66) 

67from redis.asyncio.retry import Retry 

68from redis.backoff import NoBackoff 

69from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

70from redis.exceptions import ( 

71 AuthenticationError, 

72 AuthenticationWrongNumberOfArgsError, 

73 ConnectionError, 

74 DataError, 

75 MaxConnectionsError, 

76 RedisError, 

77 ResponseError, 

78 TimeoutError, 

79) 

80from redis.observability.metrics import CloseReason 

81from redis.typing import EncodableT 

82from redis.utils import DEFAULT_RESP_VERSION, HIREDIS_AVAILABLE, str_if_bytes 

83 

84from .._parsers import ( 

85 BaseParser, 

86 Encoder, 

87 _AsyncHiredisParser, 

88 _AsyncRESP2Parser, 

89 _AsyncRESP3Parser, 

90) 

91 

# Byte tokens of the RESP wire protocol, used when packing commands.
SYM_STAR = b"*"  # prefix of an array header ("*<count>\r\n")
SYM_DOLLAR = b"$"  # prefix of a bulk-string length ("$<len>\r\n")
SYM_CRLF = b"\r\n"  # RESP line terminator
SYM_LF = b"\n"  # bare line feed
SYM_EMPTY = b""  # joiner used to concatenate buffer pieces

97 

98 

class _Sentinel(enum.Enum):
    """Single-member enum whose member marks "argument not supplied".

    Being an enum type, it can appear in annotations such as
    ``Union[list, _Sentinel]``.
    """

    sentinel = object()


# Module-wide "not supplied" marker; test for it with ``is``.
SENTINEL = _Sentinel["sentinel"]

104 

105 

# Parser class used when callers do not pass ``parser_class``: the hiredis
# parser if HIREDIS_AVAILABLE is true, otherwise the pure-Python RESP3 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP3Parser
)

111 

112 

class ConnectCallbackProtocol(Protocol):
    """Structural type for a synchronous connect callback."""

    def __call__(self, connection: "AbstractConnection"):
        """Invoked with the connection that has just been established."""


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for an asynchronous connect callback."""

    async def __call__(self, connection: "AbstractConnection"):
        """Awaited with the connection that has just been established."""


# Either flavour of connect callback is accepted.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

122 

123 

class AbstractConnection:
    """Manages communication to and from a Redis server"""

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 3,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        retry_on_error : list, optional
            Extra exception types to retry on. The given iterable is copied, so
            the caller's object is never mutated (the timeout errors added for
            ``retry_on_timeout`` go into the copy only).
        protocol : int, optional
            RESP version, 2 or 3. ``None`` selects ``DEFAULT_RESP_VERSION``.

        Raises
        ------
        DataError
            If username/password are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or an integer other than 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        elif retry_on_error is not None:
            # Copy: connection pools pass the same kwargs to every connection
            # they create, so appending to the caller's list would make it
            # grow by three entries per connection.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        # Arguments longer than this (in bytes) are emitted as separate chunks
        # by pack_command/pack_commands instead of being concatenated.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        try:
            p = int(protocol)
        except TypeError:
            # protocol=None (or another non-numeric type): use the default.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        # Reconcile parser ↔ protocol mismatches.
        # Hiredis handles both RESP2 and RESP3 natively, so only
        # pure-Python parsers need to be swapped.
        if self.protocol == 3 and parser_class == _AsyncRESP2Parser:
            parser_class = _AsyncRESP3Parser
        elif self.protocol == 2 and parser_class == _AsyncRESP3Parser:
            parser_class = _AsyncRESP2Parser
        self.set_parser(parser_class)

    def __del__(self, _warnings: Any = warnings):
        """Warn about, and best-effort close, a connection that was never closed."""
        # ``_writer`` may be unset if __init__ raised early (e.g. DataError),
        # in which case there is nothing to warn about or close. Calling
        # _close() then would raise AttributeError on the unset slot.
        if not getattr(self, "_writer", None):
            return
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        _warnings.warn(f"unclosed Connection {self!r}", ResourceWarning, source=self)

        try:
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No actions been taken if pool already closed.
            pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
        self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs rendered by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are present."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (if needed), run the handshake, then fire connect callbacks.

        :param check_health: forwarded to ``on_connect_check_health``.
        :param retry_socket_connect: when True, only the socket-level
            ``_connect`` is wrapped in ``self.retry``; when False it is tried once.
        :raises TimeoutError: if the socket connect timed out.
        :raises ConnectionError: for OS-level or other connect failures.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        async def report_connect_error(err: Exception) -> None:
            # Record the connect failure in the observability metrics.
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=err,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            err = TimeoutError("Timeout connecting to server")
            await report_connect_error(err)
            raise err
        except OSError as e:
            err = ConnectionError(self._error_message(e))
            await report_connect_error(err)
            raise err from e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func. Awaiting any
                # awaitable result (rather than testing iscoroutinefunction)
                # also supports objects with an ``async def __call__`` and
                # partials wrapping coroutine functions.
                result = self.redis_connect_func(self)
                if inspect.isawaitable(result):
                    await result
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Set the flag that ``should_reconnect`` reports."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return whether ``mark_for_reconnect`` was called since the last reset."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done by ``disconnect``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader``/``_writer`` (subclass hook)."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint description for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the RESP protocol version chosen in ``__init__``."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    def _upgrade_to_resp3_parser(self, previous_parser: BaseParser) -> None:
        """Swap a pure-Python RESP2 parser for a RESP3 one ahead of HELLO.

        ``previous_parser`` supplies the EXCEPTION_CLASSES carried over to the
        new parser. Non-RESP2 parsers (e.g. hiredis) are left untouched.
        """
        if isinstance(self._parser, _AsyncRESP2Parser):
            self.set_parser(_AsyncRESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = previous_parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the connection handshake.

        The handshake is HELLO/AUTH (or AUTH alone on RESP2), optional
        CLIENT SETNAME, pipelined CLIENT SETINFO and SELECT.

        :param check_health: forwarded to the commands sent after
            authentication; auth commands themselves never health-check,
            since a PING before AUTH would fail.
        :raises AuthenticationError: if AUTH is not answered with OK.
        :raises ConnectionError: on an unexpected RESP version (HELLO+AUTH),
            a failed SETNAME, or a failed SELECT.
        """
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            self._upgrade_to_resp3_parser(parser)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            self._upgrade_to_resp3_parser(parser)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # NOTE(review): unlike the HELLO+AUTH branch, the negotiated "proto"
            # in this reply is not validated -- confirm that is intentional.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline; SETINFO error replies are ignored
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server

        :param nowait: close the writer without awaiting ``wait_closed()``.
        :param error: the error that caused the disconnect, if any; selects
            the close reason recorded in metrics.
        :param failure_count: retry failure count; an error metric is recorded
            once it exceeds ``self.retry.get_retries()``.
        :param health_check_failed: record HEALTHCHECK_FAILED instead of ERROR.
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if necessary.

        A ``str`` is encoded with ``str.encode()``; a single ``bytes`` is sent
        as one chunk. The connection is dropped (without waiting) on any
        failure, since partially written data cannot be resent.

        :raises TimeoutError: if writing exceeds ``socket_timeout``.
        :raises ConnectionError: on an OS-level write error.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) >= 2:
                err_no, errmsg = e.args[0], e.args[1]
            else:
                # OSError("msg") has one arg; a bare OSError() has none.
                err_no, errmsg = "UNKNOWN", (e.args[0] if e.args else str(e))
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(
                f"Error while reading from {host_error}: {e.args}"
            ) from e

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command

        :param timeout: per-call timeout; when given and it expires, ``None``
            is returned instead of raising. Defaults to ``socket_timeout``.
        :param disconnect_on_error: drop the connection on read failures.
        :param push_request: forwarded to the parser on RESP3 only.
        :raises TimeoutError: if ``socket_timeout`` expires.
        :raises ConnectionError: on an OS-level read error.
        :raises ResponseError: if the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        parser_kwargs: dict = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        try:
            if read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(**parser_kwargs)
            else:
                response = await self._parser.read_response(**parser_kwargs)
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(
                f"Error while reading from {host_error} : {e.args}"
            ) from e
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        # NOTE(review): this validation is an assert, so it is skipped under -O.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # Relies on the private ``_buffer`` of asyncio.StreamReader.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read push messages until the reader's buffer is drained."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the pending token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

877 

878 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store the TCP endpoint and socket options; other kwargs go to the base class.

        :param port: coerced with ``int()``.
        :param socket_keepalive_options: ``{option: value}`` pairs applied with
            ``setsockopt(SOL_TCP, option, value)`` when ``socket_keepalive`` is set.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection

        The streams are only stored on ``self`` once all socket options have
        been applied. If an option fails, the writer is closed and the
        connection is not marked as connected, so ``is_connected`` never
        reports a dead transport.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)
            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

934 

935 

936class SSLConnection(Connection): 

937 """Manages SSL connections to and from the Redis server(s). 

938 This class extends the Connection class, adding SSL functionality, and making 

939 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) 

940 """ 

941 

942 def __init__( 

943 self, 

944 ssl_keyfile: Optional[str] = None, 

945 ssl_certfile: Optional[str] = None, 

946 ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", 

947 ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

948 ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

949 ssl_ca_certs: Optional[str] = None, 

950 ssl_ca_data: Optional[str] = None, 

951 ssl_ca_path: Optional[str] = None, 

952 ssl_check_hostname: bool = True, 

953 ssl_min_version: Optional[TLSVersion] = None, 

954 ssl_ciphers: Optional[str] = None, 

955 ssl_password: Optional[str] = None, 

956 **kwargs, 

957 ): 

958 if not SSL_AVAILABLE: 

959 raise RedisError("Python wasn't built with SSL support") 

960 

961 self.ssl_context: RedisSSLContext = RedisSSLContext( 

962 keyfile=ssl_keyfile, 

963 certfile=ssl_certfile, 

964 cert_reqs=ssl_cert_reqs, 

965 include_verify_flags=ssl_include_verify_flags, 

966 exclude_verify_flags=ssl_exclude_verify_flags, 

967 ca_certs=ssl_ca_certs, 

968 ca_data=ssl_ca_data, 

969 ca_path=ssl_ca_path, 

970 check_hostname=ssl_check_hostname, 

971 min_version=ssl_min_version, 

972 ciphers=ssl_ciphers, 

973 password=ssl_password, 

974 ) 

975 super().__init__(**kwargs) 

976 

977 def _connection_arguments(self) -> Mapping: 

978 kwargs = super()._connection_arguments() 

979 kwargs["ssl"] = self.ssl_context.get() 

980 return kwargs 

981 

982 @property 

983 def keyfile(self): 

984 return self.ssl_context.keyfile 

985 

986 @property 

987 def certfile(self): 

988 return self.ssl_context.certfile 

989 

990 @property 

991 def cert_reqs(self): 

992 return self.ssl_context.cert_reqs 

993 

994 @property 

995 def include_verify_flags(self): 

996 return self.ssl_context.include_verify_flags 

997 

998 @property 

999 def exclude_verify_flags(self): 

1000 return self.ssl_context.exclude_verify_flags 

1001 

1002 @property 

1003 def ca_certs(self): 

1004 return self.ssl_context.ca_certs 

1005 

1006 @property 

1007 def ca_data(self): 

1008 return self.ssl_context.ca_data 

1009 

1010 @property 

1011 def check_hostname(self): 

1012 return self.ssl_context.check_hostname 

1013 

1014 @property 

1015 def min_version(self): 

1016 return self.ssl_context.min_version 

1017 

1018 

1019class RedisSSLContext: 

1020 __slots__ = ( 

1021 "keyfile", 

1022 "certfile", 

1023 "cert_reqs", 

1024 "include_verify_flags", 

1025 "exclude_verify_flags", 

1026 "ca_certs", 

1027 "ca_data", 

1028 "ca_path", 

1029 "context", 

1030 "check_hostname", 

1031 "min_version", 

1032 "ciphers", 

1033 "password", 

1034 ) 

1035 

1036 def __init__( 

1037 self, 

1038 keyfile: Optional[str] = None, 

1039 certfile: Optional[str] = None, 

1040 cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, 

1041 include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

1042 exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, 

1043 ca_certs: Optional[str] = None, 

1044 ca_data: Optional[str] = None, 

1045 ca_path: Optional[str] = None, 

1046 check_hostname: bool = False, 

1047 min_version: Optional[TLSVersion] = None, 

1048 ciphers: Optional[str] = None, 

1049 password: Optional[str] = None, 

1050 ): 

1051 if not SSL_AVAILABLE: 

1052 raise RedisError("Python wasn't built with SSL support") 

1053 

1054 self.keyfile = keyfile 

1055 self.certfile = certfile 

1056 if cert_reqs is None: 

1057 cert_reqs = ssl.CERT_NONE 

1058 elif isinstance(cert_reqs, str): 

1059 CERT_REQS = { # noqa: N806 

1060 "none": ssl.CERT_NONE, 

1061 "optional": ssl.CERT_OPTIONAL, 

1062 "required": ssl.CERT_REQUIRED, 

1063 } 

1064 if cert_reqs not in CERT_REQS: 

1065 raise RedisError( 

1066 f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" 

1067 ) 

1068 cert_reqs = CERT_REQS[cert_reqs] 

1069 self.cert_reqs = cert_reqs 

1070 self.include_verify_flags = include_verify_flags 

1071 self.exclude_verify_flags = exclude_verify_flags 

1072 self.ca_certs = ca_certs 

1073 self.ca_data = ca_data 

1074 self.ca_path = ca_path 

1075 self.check_hostname = ( 

1076 check_hostname if self.cert_reqs != ssl.CERT_NONE else False 

1077 ) 

1078 self.min_version = min_version 

1079 self.ciphers = ciphers 

1080 self.password = password 

1081 self.context: Optional[SSLContext] = None 

1082 

1083 def get(self) -> SSLContext: 

1084 if not self.context: 

1085 context = ssl.create_default_context() 

1086 context.check_hostname = self.check_hostname 

1087 context.verify_mode = self.cert_reqs 

1088 if self.include_verify_flags: 

1089 for flag in self.include_verify_flags: 

1090 context.verify_flags |= flag 

1091 if self.exclude_verify_flags: 

1092 for flag in self.exclude_verify_flags: 

1093 context.verify_flags &= ~flag 

1094 if self.certfile or self.keyfile: 

1095 context.load_cert_chain( 

1096 certfile=self.certfile, 

1097 keyfile=self.keyfile, 

1098 password=self.password, 

1099 ) 

1100 if self.ca_certs or self.ca_data or self.ca_path: 

1101 context.load_verify_locations( 

1102 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data 

1103 ) 

1104 if self.min_version is not None: 

1105 context.minimum_version = self.min_version 

1106 if self.ciphers is not None: 

1107 context.set_ciphers(self.ciphers) 

1108 self.context = context 

1109 return self.context 

1110 

1111 

1112class UnixDomainSocketConnection(AbstractConnection): 

1113 "Manages UDS communication to and from a Redis server" 

1114 

1115 def __init__(self, *, path: str = "", **kwargs): 

1116 self.path = path 

1117 super().__init__(**kwargs) 

1118 

1119 def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: 

1120 pieces = [("path", self.path), ("db", self.db)] 

1121 if self.client_name: 

1122 pieces.append(("client_name", self.client_name)) 

1123 return pieces 

1124 

1125 async def _connect(self): 

1126 async with async_timeout(self.socket_connect_timeout): 

1127 reader, writer = await asyncio.open_unix_connection(path=self.path) 

1128 self._reader = reader 

1129 self._writer = writer 

1130 await self.on_connect() 

1131 

1132 def _host_error(self) -> str: 

1133 return self.path 

1134 

1135 

1136FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") 

1137 

1138 

1139def to_bool(value) -> Optional[bool]: 

1140 if value is None or value == "": 

1141 return None 

1142 if isinstance(value, str) and value.upper() in FALSE_STRINGS: 

1143 return False 

1144 return bool(value) 

1145 

1146 

1147def parse_ssl_verify_flags(value): 

1148 # flags are passed in as a string representation of a list, 

1149 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1150 verify_flags_str = value.replace("[", "").replace("]", "") 

1151 

1152 verify_flags = [] 

1153 for flag in verify_flags_str.split(","): 

1154 flag = flag.strip() 

1155 if not hasattr(VerifyFlags, flag): 

1156 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1157 verify_flags.append(getattr(VerifyFlags, flag)) 

1158 return verify_flags 

1159 

1160 

1161URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( 

1162 { 

1163 "db": int, 

1164 "socket_timeout": float, 

1165 "socket_connect_timeout": float, 

1166 "socket_keepalive": to_bool, 

1167 "retry_on_timeout": to_bool, 

1168 "max_connections": int, 

1169 "health_check_interval": int, 

1170 "ssl_check_hostname": to_bool, 

1171 "ssl_include_verify_flags": parse_ssl_verify_flags, 

1172 "ssl_exclude_verify_flags": parse_ssl_verify_flags, 

1173 "timeout": float, 

1174 } 

1175) 

1176 

1177 

1178class ConnectKwargs(TypedDict, total=False): 

1179 username: str 

1180 password: str 

1181 connection_class: Type[AbstractConnection] 

1182 host: str 

1183 port: int 

1184 db: int 

1185 path: str 

1186 

1187 

1188def parse_url(url: str) -> ConnectKwargs: 

1189 parsed: ParseResult = urlparse(url) 

1190 kwargs: ConnectKwargs = {} 

1191 

1192 for name, value_list in parse_qs(parsed.query).items(): 

1193 if value_list and len(value_list) > 0: 

1194 value = unquote(value_list[0]) 

1195 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

1196 if parser: 

1197 try: 

1198 kwargs[name] = parser(value) 

1199 except (TypeError, ValueError): 

1200 raise ValueError(f"Invalid value for '{name}' in connection URL.") 

1201 else: 

1202 kwargs[name] = value 

1203 

1204 if parsed.username: 

1205 kwargs["username"] = unquote(parsed.username) 

1206 if parsed.password: 

1207 kwargs["password"] = unquote(parsed.password) 

1208 

1209 # We only support redis://, rediss:// and unix:// schemes. 

1210 if parsed.scheme == "unix": 

1211 if parsed.path: 

1212 kwargs["path"] = unquote(parsed.path) 

1213 kwargs["connection_class"] = UnixDomainSocketConnection 

1214 

1215 elif parsed.scheme in ("redis", "rediss"): 

1216 if parsed.hostname: 

1217 kwargs["host"] = unquote(parsed.hostname) 

1218 if parsed.port: 

1219 kwargs["port"] = int(parsed.port) 

1220 

1221 # If there's a path argument, use it as the db argument if a 

1222 # querystring value wasn't specified 

1223 if parsed.path and "db" not in kwargs: 

1224 try: 

1225 kwargs["db"] = int(unquote(parsed.path).replace("/", "")) 

1226 except (AttributeError, ValueError): 

1227 pass 

1228 

1229 if parsed.scheme == "rediss": 

1230 kwargs["connection_class"] = SSLConnection 

1231 

1232 else: 

1233 valid_schemes = "redis://, rediss://, unix://" 

1234 raise ValueError( 

1235 f"Redis URL must specify one of the following schemes ({valid_schemes})" 

1236 ) 

1237 

1238 return kwargs 

1239 

1240 

1241_CP = TypeVar("_CP", bound="ConnectionPool") 

1242 

1243 

1244class ConnectionPoolInterface(ABC): 

1245 @abstractmethod 

1246 def get_protocol(self): 

1247 pass 

1248 

1249 @abstractmethod 

1250 def reset(self) -> None: 

1251 pass 

1252 

1253 @abstractmethod 

1254 @deprecated_args( 

1255 args_to_warn=["*"], 

1256 reason="Use get_connection() without args instead", 

1257 version="5.3.0", 

1258 ) 

1259 async def get_connection( 

1260 self, command_name: Optional[str] = None, *keys: Any, **options: Any 

1261 ) -> "AbstractConnection": 

1262 pass 

1263 

1264 @abstractmethod 

1265 def get_encoder(self) -> "Encoder": 

1266 pass 

1267 

1268 @abstractmethod 

1269 async def release(self, connection: "AbstractConnection") -> None: 

1270 pass 

1271 

1272 @abstractmethod 

1273 async def disconnect(self, inuse_connections: bool = True) -> None: 

1274 pass 

1275 

1276 @abstractmethod 

1277 async def aclose(self) -> None: 

1278 pass 

1279 

1280 @abstractmethod 

1281 def set_retry(self, retry: "Retry") -> None: 

1282 pass 

1283 

1284 @abstractmethod 

1285 async def re_auth_callback(self, token: TokenInterface) -> None: 

1286 pass 

1287 

1288 @abstractmethod 

1289 def get_connection_count(self) -> List[Tuple[int, dict]]: 

1290 """ 

1291 Returns a connection count (both idle and in use). 

1292 """ 

1293 pass 

1294 

1295 

1296class ConnectionPool(ConnectionPoolInterface): 

1297 """ 

1298 Create a connection pool. ``If max_connections`` is set, then this 

1299 object raises :py:class:`~redis.ConnectionError` when the pool's 

1300 limit is reached. 

1301 

1302 By default, TCP connections are created unless ``connection_class`` 

1303 is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for 

1304 unix sockets. 

1305 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

1306 

1307 Any additional keyword arguments are passed to the constructor of 

1308 ``connection_class``. 

1309 """ 

1310 

1311 @classmethod 

1312 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

1313 """ 

1314 Return a connection pool configured from the given URL. 

1315 

1316 For example:: 

1317 

1318 redis://[[username]:[password]]@localhost:6379/0 

1319 rediss://[[username]:[password]]@localhost:6379/0 

1320 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

1321 

1322 Three URL schemes are supported: 

1323 

1324 - `redis://` creates a TCP socket connection. See more at: 

1325 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

1326 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

1327 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

1328 - ``unix://``: creates a Unix Domain Socket connection. 

1329 

1330 The username, password, hostname, path and all querystring values 

1331 are passed through urllib.parse.unquote in order to replace any 

1332 percent-encoded values with their corresponding characters. 

1333 

1334 There are several ways to specify a database number. The first value 

1335 found will be used: 

1336 

1337 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

1338 

1339 2. If using the redis:// or rediss:// schemes, the path argument 

1340 of the url, e.g. redis://localhost/0 

1341 

1342 3. A ``db`` keyword argument to this function. 

1343 

1344 If none of these options are specified, the default db=0 is used. 

1345 

1346 All querystring options are cast to their appropriate Python types. 

1347 Boolean arguments can be specified with string values "True"/"False" 

1348 or "Yes"/"No". Values that cannot be properly cast cause a 

1349 ``ValueError`` to be raised. Once parsed, the querystring arguments 

1350 and keyword arguments are passed to the ``ConnectionPool``'s 

1351 class initializer. In the case of conflicting arguments, querystring 

1352 arguments always win. 

1353 """ 

1354 url_options = parse_url(url) 

1355 kwargs.update(url_options) 

1356 return cls(**kwargs) 

1357 

1358 def __init__( 

1359 self, 

1360 connection_class: Type[AbstractConnection] = Connection, 

1361 max_connections: Optional[int] = None, 

1362 **connection_kwargs, 

1363 ): 

1364 max_connections = max_connections or 2**31 

1365 if not isinstance(max_connections, int) or max_connections < 0: 

1366 raise ValueError('"max_connections" must be a positive integer') 

1367 

1368 self.connection_class = connection_class 

1369 self.connection_kwargs = connection_kwargs 

1370 self.max_connections = max_connections 

1371 

1372 self._available_connections: List[AbstractConnection] = [] 

1373 self._in_use_connections: Set[AbstractConnection] = set() 

1374 self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) 

1375 self._lock = asyncio.Lock() 

1376 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) 

1377 if self._event_dispatcher is None: 

1378 self._event_dispatcher = EventDispatcher() 

1379 

1380 # Keys that should be redacted in __repr__ to avoid exposing sensitive information 

1381 SENSITIVE_REPR_KEYS = frozenset( 

1382 { 

1383 "password", 

1384 "username", 

1385 "ssl_password", 

1386 "credential_provider", 

1387 } 

1388 ) 

1389 

1390 def __repr__(self): 

1391 conn_kwargs = ",".join( 

1392 [ 

1393 f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}" 

1394 for k, v in self.connection_kwargs.items() 

1395 ] 

1396 ) 

1397 return ( 

1398 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

1399 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

1400 f"({conn_kwargs})>)>" 

1401 ) 

1402 

1403 def get_protocol(self): 

1404 """ 

1405 Returns: 

1406 The RESP protocol version, or ``None`` if the protocol is not specified, 

1407 in which case the server default will be used. 

1408 """ 

1409 return self.connection_kwargs.get("protocol", None) 

1410 

1411 def reset(self): 

1412 # Record metrics for connections being removed before clearing 

1413 # (only if attributes exist - they won't during __init__) 

1414 if hasattr(self, "_available_connections") and hasattr( 

1415 self, "_in_use_connections" 

1416 ): 

1417 idle_count = len(self._available_connections) 

1418 in_use_count = len(self._in_use_connections) 

1419 if idle_count > 0 or in_use_count > 0: 

1420 pool_name = get_pool_name(self) 

1421 # Note: Using sync version since reset() is sync 

1422 from redis.observability.recorder import ( 

1423 record_connection_count as sync_record_connection_count, 

1424 ) 

1425 

1426 if idle_count > 0: 

1427 sync_record_connection_count( 

1428 pool_name=pool_name, 

1429 connection_state=ConnectionState.IDLE, 

1430 counter=-idle_count, 

1431 ) 

1432 if in_use_count > 0: 

1433 sync_record_connection_count( 

1434 pool_name=pool_name, 

1435 connection_state=ConnectionState.USED, 

1436 counter=-in_use_count, 

1437 ) 

1438 

1439 self._available_connections = [] 

1440 self._in_use_connections = weakref.WeakSet() 

1441 

1442 def __del__(self) -> None: 

1443 """Clean up connection pool and record metrics when garbage collected.""" 

1444 try: 

1445 if not hasattr(self, "_available_connections") or not hasattr( 

1446 self, "_in_use_connections" 

1447 ): 

1448 return 

1449 idle_count = len(self._available_connections) 

1450 in_use_count = len(self._in_use_connections) 

1451 if idle_count > 0 or in_use_count > 0: 

1452 pool_name = get_pool_name(self) 

1453 # Note: Using sync version since __del__ is sync 

1454 from redis.observability.recorder import ( 

1455 record_connection_count as sync_record_connection_count, 

1456 ) 

1457 

1458 if idle_count > 0: 

1459 sync_record_connection_count( 

1460 pool_name=pool_name, 

1461 connection_state=ConnectionState.IDLE, 

1462 counter=-idle_count, 

1463 ) 

1464 if in_use_count > 0: 

1465 sync_record_connection_count( 

1466 pool_name=pool_name, 

1467 connection_state=ConnectionState.USED, 

1468 counter=-in_use_count, 

1469 ) 

1470 except Exception: 

1471 pass 

1472 

1473 def can_get_connection(self) -> bool: 

1474 """Return True if a connection can be retrieved from the pool.""" 

1475 return ( 

1476 self._available_connections 

1477 or len(self._in_use_connections) < self.max_connections 

1478 ) 

1479 

1480 @deprecated_args( 

1481 args_to_warn=["*"], 

1482 reason="Use get_connection() without args instead", 

1483 version="5.3.0", 

1484 ) 

1485 async def get_connection(self, command_name=None, *keys, **options): 

1486 """Get a connected connection from the pool""" 

1487 # Track connection count before to detect if a new connection is created 

1488 async with self._lock: 

1489 connections_before = len(self._available_connections) + len( 

1490 self._in_use_connections 

1491 ) 

1492 start_time_created = time.monotonic() 

1493 connection = self.get_available_connection() 

1494 connections_after = len(self._available_connections) + len( 

1495 self._in_use_connections 

1496 ) 

1497 is_created = connections_after > connections_before 

1498 

1499 # Record state transition for observability 

1500 # This ensures counters stay balanced if ensure_connection() fails and release() is called 

1501 pool_name = get_pool_name(self) 

1502 if is_created: 

1503 # New connection created and acquired: just USED +1 

1504 await record_connection_count( 

1505 pool_name=pool_name, 

1506 connection_state=ConnectionState.USED, 

1507 counter=1, 

1508 ) 

1509 else: 

1510 # Existing connection acquired from pool: IDLE -> USED 

1511 await record_connection_count( 

1512 pool_name=pool_name, 

1513 connection_state=ConnectionState.IDLE, 

1514 counter=-1, 

1515 ) 

1516 await record_connection_count( 

1517 pool_name=pool_name, 

1518 connection_state=ConnectionState.USED, 

1519 counter=1, 

1520 ) 

1521 

1522 # We now perform the connection check outside of the lock. 

1523 try: 

1524 await self.ensure_connection(connection) 

1525 

1526 if is_created: 

1527 await record_connection_create_time( 

1528 connection_pool=self, 

1529 duration_seconds=time.monotonic() - start_time_created, 

1530 ) 

1531 

1532 return connection 

1533 except BaseException: 

1534 await self.release(connection) 

1535 raise 

1536 

1537 def get_available_connection(self): 

1538 """Get a connection from the pool, without making sure it is connected""" 

1539 try: 

1540 connection = self._available_connections.pop() 

1541 except IndexError: 

1542 if len(self._in_use_connections) >= self.max_connections: 

1543 raise MaxConnectionsError("Too many connections") from None 

1544 connection = self.make_connection() 

1545 self._in_use_connections.add(connection) 

1546 return connection 

1547 

1548 def get_encoder(self): 

1549 """Return an encoder based on encoding settings""" 

1550 kwargs = self.connection_kwargs 

1551 return self.encoder_class( 

1552 encoding=kwargs.get("encoding", "utf-8"), 

1553 encoding_errors=kwargs.get("encoding_errors", "strict"), 

1554 decode_responses=kwargs.get("decode_responses", False), 

1555 ) 

1556 

1557 def make_connection(self): 

1558 """Create a new connection. Can be overridden by child classes.""" 

1559 # Note: We don't record IDLE here because async uses a sync make_connection 

1560 # but async record_connection_count. The recording is handled in get_connection. 

1561 return self.connection_class(**self.connection_kwargs) 

1562 

1563 async def ensure_connection(self, connection: AbstractConnection): 

1564 """Ensure that the connection object is connected and valid""" 

1565 await connection.connect() 

1566 # connections that the pool provides should be ready to send 

1567 # a command. if not, the connection was either returned to the 

1568 # pool before all data has been read or the socket has been 

1569 # closed. either way, reconnect and verify everything is good. 

1570 try: 

1571 if await connection.can_read_destructive(): 

1572 raise ConnectionError("Connection has data") from None 

1573 except (ConnectionError, TimeoutError, OSError): 

1574 await connection.disconnect() 

1575 await connection.connect() 

1576 if await connection.can_read_destructive(): 

1577 raise ConnectionError("Connection not ready") from None 

1578 

1579 async def release(self, connection: AbstractConnection): 

1580 """Releases the connection back to the pool""" 

1581 # Connections should always be returned to the correct pool, 

1582 # not doing so is an error that will cause an exception here. 

1583 self._in_use_connections.remove(connection) 

1584 

1585 if connection.should_reconnect(): 

1586 await connection.disconnect() 

1587 

1588 self._available_connections.append(connection) 

1589 await self._event_dispatcher.dispatch_async( 

1590 AsyncAfterConnectionReleasedEvent(connection) 

1591 ) 

1592 

1593 # Record state transition: USED -> IDLE 

1594 pool_name = get_pool_name(self) 

1595 await record_connection_count( 

1596 pool_name=pool_name, 

1597 connection_state=ConnectionState.USED, 

1598 counter=-1, 

1599 ) 

1600 await record_connection_count( 

1601 pool_name=pool_name, 

1602 connection_state=ConnectionState.IDLE, 

1603 counter=1, 

1604 ) 

1605 

1606 async def disconnect(self, inuse_connections: bool = True): 

1607 """ 

1608 Disconnects connections in the pool 

1609 

1610 If ``inuse_connections`` is True, disconnect connections that are 

1611 current in use, potentially by other tasks. Otherwise only disconnect 

1612 connections that are idle in the pool. 

1613 """ 

1614 if inuse_connections: 

1615 connections: Iterable[AbstractConnection] = chain( 

1616 self._available_connections, self._in_use_connections 

1617 ) 

1618 else: 

1619 connections = self._available_connections 

1620 resp = await asyncio.gather( 

1621 *(connection.disconnect() for connection in connections), 

1622 return_exceptions=True, 

1623 ) 

1624 

1625 exc = next((r for r in resp if isinstance(r, BaseException)), None) 

1626 if exc: 

1627 raise exc 

1628 

1629 async def update_active_connections_for_reconnect(self): 

1630 """ 

1631 Mark all active connections for reconnect. 

1632 """ 

1633 async with self._lock: 

1634 for conn in self._in_use_connections: 

1635 conn.mark_for_reconnect() 

1636 

1637 async def aclose(self) -> None: 

1638 """Close the pool, disconnecting all connections""" 

1639 await self.disconnect() 

1640 

1641 def set_retry(self, retry: "Retry") -> None: 

1642 for conn in self._available_connections: 

1643 conn.retry = retry 

1644 for conn in self._in_use_connections: 

1645 conn.retry = retry 

1646 

1647 async def re_auth_callback(self, token: TokenInterface): 

1648 async with self._lock: 

1649 for conn in self._available_connections: 

1650 await conn.retry.call_with_retry( 

1651 lambda: conn.send_command( 

1652 "AUTH", token.try_get("oid"), token.get_value() 

1653 ), 

1654 lambda error: self._mock(error), 

1655 ) 

1656 await conn.retry.call_with_retry( 

1657 lambda: conn.read_response(), lambda error: self._mock(error) 

1658 ) 

1659 for conn in self._in_use_connections: 

1660 conn.set_re_auth_token(token) 

1661 

1662 async def _mock(self, error: RedisError): 

1663 """ 

1664 Dummy functions, needs to be passed as error callback to retry object. 

1665 :param error: 

1666 :return: 

1667 """ 

1668 pass 

1669 

1670 def get_connection_count(self) -> List[tuple[int, dict]]: 

1671 """ 

1672 Returns a connection count (both idle and in use). 

1673 """ 

1674 attributes = AttributeBuilder.build_base_attributes() 

1675 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self) 

1676 free_connections_attributes = attributes.copy() 

1677 in_use_connections_attributes = attributes.copy() 

1678 

1679 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

1680 ConnectionState.IDLE.value 

1681 ) 

1682 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

1683 ConnectionState.USED.value 

1684 ) 

1685 

1686 return [ 

1687 (len(self._available_connections), free_connections_attributes), 

1688 (len(self._in_use_connections), in_use_connections_attributes), 

1689 ] 

1690 

1691 

1692class BlockingConnectionPool(ConnectionPool): 

1693 """ 

1694 A blocking connection pool:: 

1695 

1696 >>> from redis.asyncio import Redis, BlockingConnectionPool 

1697 >>> client = Redis.from_pool(BlockingConnectionPool()) 

1698 

1699 It performs the same function as the default 

1700 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, 

1701 it maintains a pool of reusable connections that can be shared by 

1702 multiple async redis clients. 

1703 

1704 The difference is that, in the event that a client tries to get a 

1705 connection from the pool when all of connections are in use, rather than 

1706 raising a :py:class:`~redis.ConnectionError` (as the default 

1707 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it 

1708 blocks the current `Task` for a specified number of seconds until 

1709 a connection becomes available. 

1710 

1711 Use ``max_connections`` to increase / decrease the pool size:: 

1712 

1713 >>> pool = BlockingConnectionPool(max_connections=10) 

1714 

1715 Use ``timeout`` to tell it either how many seconds to wait for a connection 

1716 to become available, or to block forever: 

1717 

1718 >>> # Block forever. 

1719 >>> pool = BlockingConnectionPool(timeout=None) 

1720 

1721 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

1722 >>> # not available. 

1723 >>> pool = BlockingConnectionPool(timeout=5) 

1724 """ 

1725 

1726 def __init__( 

1727 self, 

1728 max_connections: int = 50, 

1729 timeout: Optional[float] = 20, 

1730 connection_class: Type[AbstractConnection] = Connection, 

1731 queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated 

1732 **connection_kwargs, 

1733 ): 

1734 super().__init__( 

1735 connection_class=connection_class, 

1736 max_connections=max_connections, 

1737 **connection_kwargs, 

1738 ) 

1739 self._condition = asyncio.Condition() 

1740 self.timeout = timeout 

1741 

1742 @deprecated_args( 

1743 args_to_warn=["*"], 

1744 reason="Use get_connection() without args instead", 

1745 version="5.3.0", 

1746 ) 

1747 async def get_connection(self, command_name=None, *keys, **options): 

1748 """Gets a connection from the pool, blocking until one is available""" 

1749 # Start timing for wait time observability 

1750 start_time_acquired = time.monotonic() 

1751 

1752 try: 

1753 async with self._condition: 

1754 async with async_timeout(self.timeout): 

1755 await self._condition.wait_for(self.can_get_connection) 

1756 # Track connection count before to detect if a new connection is created 

1757 connections_before = len(self._available_connections) + len( 

1758 self._in_use_connections 

1759 ) 

1760 start_time_created = time.monotonic() 

1761 connection = super().get_available_connection() 

1762 connections_after = len(self._available_connections) + len( 

1763 self._in_use_connections 

1764 ) 

1765 is_created = connections_after > connections_before 

1766 except asyncio.TimeoutError as err: 

1767 raise ConnectionError("No connection available.") from err 

1768 

1769 # We now perform the connection check outside of the lock. 

1770 try: 

1771 await self.ensure_connection(connection) 

1772 

1773 if is_created: 

1774 await record_connection_create_time( 

1775 connection_pool=self, 

1776 duration_seconds=time.monotonic() - start_time_created, 

1777 ) 

1778 

1779 await record_connection_wait_time( 

1780 pool_name=get_pool_name(self), 

1781 duration_seconds=time.monotonic() - start_time_acquired, 

1782 ) 

1783 

1784 return connection 

1785 except BaseException: 

1786 await self.release(connection) 

1787 raise 

1788 

1789 async def release(self, connection: AbstractConnection): 

1790 """Releases the connection back to the pool.""" 

1791 async with self._condition: 

1792 await super().release(connection) 

1793 self._condition.notify()